diff --git a/.gitignore b/.gitignore index 62db80cd9b4d..101ddbbf2631 100644 --- a/.gitignore +++ b/.gitignore @@ -48,7 +48,9 @@ Vagrantfile # Configs *.hcl !command/agent/config/test-fixtures/config.hcl +!command/agent/config/test-fixtures/config-cache.hcl !command/agent/config/test-fixtures/config-embedded-type.hcl +!command/agent/config/test-fixtures/config-cache-embedded-type.hcl .DS_Store .idea diff --git a/api/client.go b/api/client.go index 80ccd7d50290..432624dd0379 100644 --- a/api/client.go +++ b/api/client.go @@ -25,6 +25,7 @@ import ( "golang.org/x/time/rate" ) +const EnvVaultAgentAddress = "VAULT_AGENT_ADDR" const EnvVaultAddress = "VAULT_ADDR" const EnvVaultCACert = "VAULT_CACERT" const EnvVaultCAPath = "VAULT_CAPATH" @@ -237,6 +238,10 @@ func (c *Config) ReadEnvironment() error { if v := os.Getenv(EnvVaultAddress); v != "" { envAddress = v } + // Agent's address will take precedence over Vault's address + if v := os.Getenv(EnvVaultAgentAddress); v != "" { + envAddress = v + } if v := os.Getenv(EnvVaultMaxRetries); v != "" { maxRetries, err := strconv.ParseUint(v, 10, 32) if err != nil { @@ -366,6 +371,21 @@ func NewClient(c *Config) (*Client, error) { c.modifyLock.Lock() defer c.modifyLock.Unlock() + // If address begins with a `unix://`, treat it as a socket file path and set + // the HttpClient's transport to the corresponding socket dialer. + if strings.HasPrefix(c.Address, "unix://") { + socketFilePath := strings.TrimPrefix(c.Address, "unix://") + c.HttpClient = &http.Client{ + Transport: &http.Transport{ + DialContext: func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", socketFilePath) + }, + }, + } + // Set the unix address for URL parsing below + c.Address = "http://unix" + } + u, err := url.Parse(c.Address) if err != nil { return nil, err @@ -707,7 +727,7 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon redirectCount := 0 START: - req, err := r.toRetryableHTTP() + req, err := r.ToRetryableHTTP() if err != nil { return nil, err } diff --git a/api/request.go b/api/request.go index 4efa2aa84177..41d45720fea7 100644 --- a/api/request.go +++ b/api/request.go @@ -62,7 +62,7 @@ func (r *Request) ResetJSONBody() error { // DEPRECATED: ToHTTP turns this request into a valid *http.Request for use // with the net/http package. func (r *Request) ToHTTP() (*http.Request, error) { - req, err := r.toRetryableHTTP() + req, err := r.ToRetryableHTTP() if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (r *Request) ToHTTP() (*http.Request, error) { return req.Request, nil } -func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) { +func (r *Request) ToRetryableHTTP() (*retryablehttp.Request, error) { // Encode the query parameters r.URL.RawQuery = r.Params.Encode() diff --git a/api/secret.go b/api/secret.go index e25962604b4e..c8a0ba3d9d2c 100644 --- a/api/secret.go +++ b/api/secret.go @@ -292,6 +292,7 @@ type SecretAuth struct { TokenPolicies []string `json:"token_policies"` IdentityPolicies []string `json:"identity_policies"` Metadata map[string]string `json:"metadata"` + Orphan bool `json:"orphan"` LeaseDuration int `json:"lease_duration"` Renewable bool `json:"renewable"` diff --git a/command/agent.go b/command/agent.go index 92c93c70c2e1..4fb8f09f2a31 100644 --- a/command/agent.go +++ b/command/agent.go @@ -4,6 +4,10 @@ import ( "context" "fmt" "io" + "net" + "net/http" + "time" + "os" "sort" "strings" @@ -23,6 +27,7 @@ import ( "github.com/hashicorp/vault/command/agent/auth/gcp" "github.com/hashicorp/vault/command/agent/auth/jwt" "github.com/hashicorp/vault/command/agent/auth/kubernetes" + "github.com/hashicorp/vault/command/agent/cache" "github.com/hashicorp/vault/command/agent/config" "github.com/hashicorp/vault/command/agent/sink" "github.com/hashicorp/vault/command/agent/sink/file" @@ -218,19 +223,6 @@ func (c *AgentCommand) Run(args []string) int { info["cgo"] = "enabled" } - // Server configuration output - padding := 24 - sort.Strings(infoKeys) - c.UI.Output("==> Vault agent configuration:\n") - for _, k := range infoKeys { - c.UI.Output(fmt.Sprintf( - "%s%s: %s", - strings.Repeat(" ", padding-len(k)), - strings.Title(k), - info[k])) - } - c.UI.Output("") - // Tests might not want to start a vault server and just want to verify // the configuration. if c.flagTestVerifyOnly { @@ -332,10 +324,92 @@ func (c *AgentCommand) Run(args []string) int { EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials, }) - // Start things running + // Start auto-auth and sink servers go ah.Run(ctx, method) go ss.Run(ctx, ah.OutputCh, sinks) + // Parse agent listener configurations + if config.Cache != nil && len(config.Cache.Listeners) != 0 { + cacheLogger := c.logger.Named("cache") + + // Create the API proxier + apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{ + Logger: cacheLogger.Named("apiproxy"), + }) + + // Create the lease cache proxier and set its underlying proxier to + // the API proxier. + leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{ + BaseContext: ctx, + Proxier: apiProxy, + Logger: cacheLogger.Named("leasecache"), + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err)) + return 1 + } + + // Create a muxer and add paths relevant for the lease cache layer + mux := http.NewServeMux() + mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx)) + + mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, config.Cache.UseAutoAuthToken, c.client)) + + var listeners []net.Listener + for i, lnConfig := range config.Cache.Listeners { + listener, props, _, err := cache.ServerListener(lnConfig, c.logWriter, c.UI) + if err != nil { + c.UI.Error(fmt.Sprintf("Error parsing listener configuration: %v", err)) + return 1 + } + + listeners = append(listeners, listener) + + scheme := "https://" + if props["tls"] == "disabled" { + scheme = "http://" + } + if lnConfig.Type == "unix" { + scheme = "unix://" + } + + infoKey := fmt.Sprintf("api address %d", i+1) + info[infoKey] = scheme + listener.Addr().String() + infoKeys = append(infoKeys, infoKey) + + cacheLogger.Info("starting listener", "addr", listener.Addr().String()) + server := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: cacheLogger.StandardLogger(nil), + } + go server.Serve(listener) + } + + // Ensure that listeners are closed at all the exits + listenerCloseFunc := func() { + for _, ln := range listeners { + ln.Close() + } + } + defer c.cleanupGuard.Do(listenerCloseFunc) + } + + // Server configuration output + padding := 24 + sort.Strings(infoKeys) + c.UI.Output("==> Vault agent configuration:\n") + for _, k := range infoKeys { + c.UI.Output(fmt.Sprintf( + "%s%s: %s", + strings.Repeat(" ", padding-len(k)), + strings.Title(k), + info[k])) + } + c.UI.Output("") + // Release the log gate. c.logGate.Flush() diff --git a/command/agent/cache/api_proxy.go b/command/agent/cache/api_proxy.go new file mode 100644 index 000000000000..43469a8ce369 --- /dev/null +++ b/command/agent/cache/api_proxy.go @@ -0,0 +1,61 @@ +package cache + +import ( + "bytes" + "context" + "io/ioutil" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" +) + +// APIProxy is an implementation of the proxier interface that is used to +// forward the request to Vault and get the response. +type APIProxy struct { + logger hclog.Logger +} + +type APIProxyConfig struct { + Logger hclog.Logger +} + +func NewAPIProxy(config *APIProxyConfig) Proxier { + return &APIProxy{ + logger: config.Logger, + } +} + +func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { + client, err := api.NewClient(api.DefaultConfig()) + if err != nil { + return nil, err + } + client.SetToken(req.Token) + client.SetHeaders(req.Request.Header) + + fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path) + fwReq.BodyBytes = req.RequestBody + + // Make the request to Vault and get the response + ap.logger.Info("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method) + resp, err := client.RawRequestWithContext(ctx, fwReq) + if err != nil { + return nil, err + } + + // Parse and reset response body + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + ap.logger.Error("failed to read request body", "error", err) + return nil, err + } + if resp.Body != nil { + resp.Body.Close() + } + resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody)) + + return &SendResponse{ + Response: resp, + ResponseBody: respBody, + }, nil +} diff --git a/command/agent/cache/api_proxy_test.go b/command/agent/cache/api_proxy_test.go new file mode 100644 index 000000000000..9a68acd36d31 --- /dev/null +++ b/command/agent/cache/api_proxy_test.go @@ -0,0 +1,43 @@ +package cache + +import ( + "testing" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/logging" + "github.com/hashicorp/vault/helper/namespace" +) + +func TestCache_APIProxy(t *testing.T) { + cleanup, client, _, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil) + defer cleanup() + + proxier := NewAPIProxy(&APIProxyConfig{ + Logger: logging.NewVaultLogger(hclog.Trace), + }) + + r := client.NewRequest("GET", "/v1/sys/health") + req, err := r.ToRetryableHTTP() + if err != nil { + t.Fatal(err) + } + + resp, err := proxier.Send(namespace.RootContext(nil), &SendRequest{ + Request: req.Request, + }) + if err != nil { + t.Fatal(err) + } + + var result api.HealthResponse + err = jsonutil.DecodeJSONFromReader(resp.Response.Body, &result) + if err != nil { + t.Fatal(err) + } + + if !result.Initialized || result.Sealed || result.Standby { + t.Fatalf("bad sys/health response") + } +} diff --git a/command/agent/cache/cache_test.go b/command/agent/cache/cache_test.go new file mode 100644 index 000000000000..34f6b4b853f7 --- /dev/null +++ b/command/agent/cache/cache_test.go @@ -0,0 +1,926 @@ +package cache + +import ( + "context" + "fmt" + "net" + "net/http" + "os" + "testing" + "time" + + "github.com/hashicorp/vault/logical" + + "github.com/go-test/deep" + hclog "github.com/hashicorp/go-hclog" + kv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/helper/logging" + "github.com/hashicorp/vault/helper/namespace" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/vault" +) + +const policyAdmin = ` +path "*" { + capabilities = ["sudo", "create", "read", "update", "delete", "list"] +} +` + +// setupClusterAndAgent is a helper func used to set up a test cluster and +// caching agent. It returns a cleanup func that should be deferred immediately +// along with two clients, one for direct cluster communication and another to +// talk to the caching agent. +func setupClusterAndAgent(ctx context.Context, t *testing.T, coreConfig *vault.CoreConfig) (func(), *api.Client, *api.Client, *LeaseCache) { + t.Helper() + + if ctx == nil { + ctx = context.Background() + } + + // Handle sane defaults + if coreConfig == nil { + coreConfig = &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: logging.NewVaultLogger(hclog.Trace), + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + } + } + + if coreConfig.CredentialBackends == nil { + coreConfig.CredentialBackends = map[string]logical.Factory{ + "userpass": userpass.Factory, + } + } + + // Init new test cluster + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + + // clusterClient is the client that is used to talk directly to the cluster. + clusterClient := cores[0].Client + + // Add an admin policy + if err := clusterClient.Sys().PutPolicy("admin", policyAdmin); err != nil { + t.Fatal(err) + } + + // Set up the userpass auth backend and an admin user. Used for getting a token + // for the agent later down in this func. + clusterClient.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + + _, err := clusterClient.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ + "password": "bar", + "policies": []string{"admin"}, + }) + if err != nil { + t.Fatal(err) + } + + // Set up env vars for agent consumption + origEnvVaultAddress := os.Getenv(api.EnvVaultAddress) + os.Setenv(api.EnvVaultAddress, clusterClient.Address()) + + origEnvVaultCACert := os.Getenv(api.EnvVaultCACert) + os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir)) + + cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache") + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + // Create the API proxier + apiProxy := NewAPIProxy(&APIProxyConfig{ + Logger: cacheLogger.Named("apiproxy"), + }) + + // Create the lease cache proxier and set its underlying proxier to + // the API proxier. + leaseCache, err := NewLeaseCache(&LeaseCacheConfig{ + BaseContext: ctx, + Proxier: apiProxy, + Logger: cacheLogger.Named("leasecache"), + }) + if err != nil { + t.Fatal(err) + } + + // Create a muxer and add paths relevant for the lease cache layer + mux := http.NewServeMux() + mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx)) + + mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, false, clusterClient)) + server := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: cacheLogger.StandardLogger(nil), + } + go server.Serve(listener) + + // testClient is the client that is used to talk to the agent for proxying/caching behavior. + testClient, err := clusterClient.Clone() + if err != nil { + t.Fatal(err) + } + + if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil { + t.Fatal(err) + } + + // Login via userpass method to derive a managed token. Set that token as the + // testClient's token + resp, err := testClient.Logical().Write("auth/userpass/login/foo", map[string]interface{}{ + "password": "bar", + }) + if err != nil { + t.Fatal(err) + } + testClient.SetToken(resp.Auth.ClientToken) + + cleanup := func() { + cluster.Cleanup() + os.Setenv(api.EnvVaultAddress, origEnvVaultAddress) + os.Setenv(api.EnvVaultCACert, origEnvVaultCACert) + listener.Close() + } + + return cleanup, clusterClient, testClient, leaseCache +} + +func tokenRevocationValidation(t *testing.T, sampleSpace map[string]string, expected map[string]string, leaseCache *LeaseCache) { + t.Helper() + for val, valType := range sampleSpace { + index, err := leaseCache.db.Get(valType, val) + if err != nil { + t.Fatal(err) + } + if expected[val] == "" && index != nil { + t.Fatalf("failed to evict index from the cache: type: %q, value: %q", valType, val) + } + if expected[val] != "" && index == nil { + t.Fatalf("evicted an undesired index from cache: type: %q, value: %q", valType, val) + } + } +} + +func TestCache_TokenRevocations_RevokeOrphan(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + sampleSpace := make(map[string]string) + + cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + token1 := testClient.Token() + sampleSpace[token1] = "token" + + // Mount the kv backend + err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Create a secret in the backend + _, err = testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + // Read the secret and create a lease + leaseResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease1 := leaseResp.LeaseID + sampleSpace[lease1] = "lease" + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token2 := resp.Auth.ClientToken + sampleSpace[token2] = "token" + + testClient.SetToken(token2) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease2 := leaseResp.LeaseID + sampleSpace[lease2] = "lease" + + resp, err = testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token3 := resp.Auth.ClientToken + sampleSpace[token3] = "token" + + testClient.SetToken(token3) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease3 := leaseResp.LeaseID + sampleSpace[lease3] = "lease" + + expected := make(map[string]string) + for k, v := range sampleSpace { + expected[k] = v + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) + + // Revoke-orphan the intermediate token. This should result in its own + // eviction and evictions of the revoked token's leases. All other things + // including the child tokens and leases of the child tokens should be + // untouched. + testClient.SetToken(token2) + err = testClient.Auth().Token().RevokeOrphan(token2) + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + + expected = map[string]string{ + token1: "token", + lease1: "lease", + token3: "token", + lease3: "lease", + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) +} + +func TestCache_TokenRevocations_LeafLevelToken(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + sampleSpace := make(map[string]string) + + cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + token1 := testClient.Token() + sampleSpace[token1] = "token" + + // Mount the kv backend + err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Create a secret in the backend + _, err = testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + // Read the secret and create a lease + leaseResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease1 := leaseResp.LeaseID + sampleSpace[lease1] = "lease" + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token2 := resp.Auth.ClientToken + sampleSpace[token2] = "token" + + testClient.SetToken(token2) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease2 := leaseResp.LeaseID + sampleSpace[lease2] = "lease" + + resp, err = testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token3 := resp.Auth.ClientToken + sampleSpace[token3] = "token" + + testClient.SetToken(token3) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease3 := leaseResp.LeaseID + sampleSpace[lease3] = "lease" + + expected := make(map[string]string) + for k, v := range sampleSpace { + expected[k] = v + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) + + // Revoke the lef token. This should evict all the leases belonging to this + // token, evict entries for all the child tokens and their respective + // leases. + testClient.SetToken(token3) + err = testClient.Auth().Token().RevokeSelf("") + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + + expected = map[string]string{ + token1: "token", + lease1: "lease", + token2: "token", + lease2: "lease", + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) +} + +func TestCache_TokenRevocations_IntermediateLevelToken(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + sampleSpace := make(map[string]string) + + cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + token1 := testClient.Token() + sampleSpace[token1] = "token" + + // Mount the kv backend + err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Create a secret in the backend + _, err = testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + // Read the secret and create a lease + leaseResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease1 := leaseResp.LeaseID + sampleSpace[lease1] = "lease" + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token2 := resp.Auth.ClientToken + sampleSpace[token2] = "token" + + testClient.SetToken(token2) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease2 := leaseResp.LeaseID + sampleSpace[lease2] = "lease" + + resp, err = testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token3 := resp.Auth.ClientToken + sampleSpace[token3] = "token" + + testClient.SetToken(token3) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease3 := leaseResp.LeaseID + sampleSpace[lease3] = "lease" + + expected := make(map[string]string) + for k, v := range sampleSpace { + expected[k] = v + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) + + // Revoke the second level token. This should evict all the leases + // belonging to this token, evict entries for all the child tokens and + // their respective leases. + testClient.SetToken(token2) + err = testClient.Auth().Token().RevokeSelf("") + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + + expected = map[string]string{ + token1: "token", + lease1: "lease", + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) +} + +func TestCache_TokenRevocations_TopLevelToken(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + sampleSpace := make(map[string]string) + + cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + token1 := testClient.Token() + sampleSpace[token1] = "token" + + // Mount the kv backend + err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Create a secret in the backend + _, err = testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + // Read the secret and create a lease + leaseResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease1 := leaseResp.LeaseID + sampleSpace[lease1] = "lease" + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token2 := resp.Auth.ClientToken + sampleSpace[token2] = "token" + + testClient.SetToken(token2) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease2 := leaseResp.LeaseID + sampleSpace[lease2] = "lease" + + resp, err = testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token3 := resp.Auth.ClientToken + sampleSpace[token3] = "token" + + testClient.SetToken(token3) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease3 := leaseResp.LeaseID + sampleSpace[lease3] = "lease" + + expected := make(map[string]string) + for k, v := range sampleSpace { + expected[k] = v + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) + + // Revoke the top level token. This should evict all the leases belonging + // to this token, evict entries for all the child tokens and their + // respective leases. + testClient.SetToken(token1) + err = testClient.Auth().Token().RevokeSelf("") + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + + expected = make(map[string]string) + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) +} + +func TestCache_TokenRevocations_Shutdown(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + sampleSpace := make(map[string]string) + + ctx, rootCancelFunc := context.WithCancel(namespace.RootContext(nil)) + cleanup, _, testClient, leaseCache := setupClusterAndAgent(ctx, t, coreConfig) + defer cleanup() + + token1 := testClient.Token() + sampleSpace[token1] = "token" + + // Mount the kv backend + err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Create a secret in the backend + _, err = testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + // Read the secret and create a lease + leaseResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease1 := leaseResp.LeaseID + sampleSpace[lease1] = "lease" + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token2 := resp.Auth.ClientToken + sampleSpace[token2] = "token" + + testClient.SetToken(token2) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease2 := leaseResp.LeaseID + sampleSpace[lease2] = "lease" + + resp, err = testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token3 := resp.Auth.ClientToken + sampleSpace[token3] = "token" + + testClient.SetToken(token3) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease3 := leaseResp.LeaseID + sampleSpace[lease3] = "lease" + + expected := make(map[string]string) + for k, v := range sampleSpace { + expected[k] = v + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) + + rootCancelFunc() + time.Sleep(1 * time.Second) + + // Ensure that all the entries are now gone + expected = make(map[string]string) + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) +} + +func TestCache_TokenRevocations_BaseContextCancellation(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + sampleSpace := make(map[string]string) + + cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + token1 := testClient.Token() + sampleSpace[token1] = "token" + + // Mount the kv backend + err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Create a secret in the backend + _, err = testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + // Read the secret and create a lease + leaseResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease1 := leaseResp.LeaseID + sampleSpace[lease1] = "lease" + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token2 := resp.Auth.ClientToken + sampleSpace[token2] = "token" + + testClient.SetToken(token2) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease2 := leaseResp.LeaseID + sampleSpace[lease2] = "lease" + + resp, err = testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token3 := resp.Auth.ClientToken + sampleSpace[token3] = "token" + + testClient.SetToken(token3) + + leaseResp, err = testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + lease3 := leaseResp.LeaseID + sampleSpace[lease3] = "lease" + + expected := make(map[string]string) + for k, v := range sampleSpace { + expected[k] = v + } + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) + + // Cancel the base context of the lease cache. This should trigger + // evictions of all the entries from the cache. + leaseCache.baseCtxInfo.CancelFunc() + time.Sleep(1 * time.Second) + + // Ensure that all the entries are now gone + expected = make(map[string]string) + tokenRevocationValidation(t, sampleSpace, expected, leaseCache) +} + +func TestCache_NonCacheable(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": kv.Factory, + }, + } + + cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + // Query mounts first + origMounts, err := testClient.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + // Mount a kv backend + if err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + Options: map[string]string{ + "version": "2", + }, + }); err != nil { + t.Fatal(err) + } + + // Query mounts again + newMounts, err := testClient.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(origMounts, newMounts); diff == nil { + t.Logf("response #1: %#v", origMounts) + t.Logf("response #2: %#v", newMounts) + t.Fatal("expected requests to be not cached") + } +} + +func TestCache_AuthResponse(t *testing.T) { + cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil) + defer cleanup() + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token := resp.Auth.ClientToken + testClient.SetToken(token) + + authTokeCreateReq := func(t *testing.T, policies map[string]interface{}) *api.Secret { + resp, err := testClient.Logical().Write("auth/token/create", policies) + if err != nil { + t.Fatal(err) + } + if resp.Auth == nil || resp.Auth.ClientToken == "" { + t.Fatalf("expected a valid client token in the response, got = %#v", resp) + } + + return resp + } + + // Test on auth response by creating a child token + { + proxiedResp := authTokeCreateReq(t, map[string]interface{}{ + "policies": "default", + }) + + cachedResp := authTokeCreateReq(t, map[string]interface{}{ + "policies": "default", + }) + + if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil { + t.Fatal(diff) + } + } + + // Test on *non-renewable* auth response by creating a child root token + { + proxiedResp := authTokeCreateReq(t, nil) + + cachedResp := authTokeCreateReq(t, nil) + + if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil { + t.Fatal(diff) + } + } +} + +func TestCache_LeaseResponse(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + cleanup, client, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig) + defer cleanup() + + err := client.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Test proxy by issuing two different requests + { + // Write data to the lease-kv backend + _, err := testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + _, err = testClient.Logical().Write("kv/foobar", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + firstResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + + secondResp, err := testClient.Logical().Read("kv/foobar") + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(firstResp, secondResp); diff == nil { + t.Logf("response: %#v", firstResp) + t.Fatal("expected proxied responses, got cached response on second request") + } + } + + // Test caching behavior by issue the same request twice + { + _, err := testClient.Logical().Write("kv/baz", map[string]interface{}{ + "value": "foo", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + proxiedResp, err := testClient.Logical().Read("kv/baz") + if err != nil { + t.Fatal(err) + } + + cachedResp, err := testClient.Logical().Read("kv/baz") + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(proxiedResp, cachedResp); diff != nil { + t.Fatal(diff) + } + } +} diff --git a/command/agent/cache/cachememdb/cache_memdb.go b/command/agent/cache/cachememdb/cache_memdb.go new file mode 100644 index 000000000000..8f9aabfdd295 --- /dev/null +++ b/command/agent/cache/cachememdb/cache_memdb.go @@ -0,0 +1,265 @@ +package cachememdb + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" +) + +const ( + tableNameIndexer = "indexer" +) + +// CacheMemDB is the underlying cache database for storing indexes. +type CacheMemDB struct { + db *memdb.MemDB +} + +// New creates a new instance of CacheMemDB. +func New() (*CacheMemDB, error) { + db, err := newDB() + if err != nil { + return nil, err + } + + return &CacheMemDB{ + db: db, + }, nil +} + +func newDB() (*memdb.MemDB, error) { + cacheSchema := &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + tableNameIndexer: &memdb.TableSchema{ + Name: tableNameIndexer, + Indexes: map[string]*memdb.IndexSchema{ + // This index enables fetching the cached item based on the + // identifier of the index. + IndexNameID: &memdb.IndexSchema{ + Name: IndexNameID, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ID", + }, + }, + // This index enables fetching all the entries in cache for + // a given request path, in a given namespace. + IndexNameRequestPath: &memdb.IndexSchema{ + Name: IndexNameRequestPath, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Namespace", + }, + &memdb.StringFieldIndex{ + Field: "RequestPath", + }, + }, + }, + }, + // This index enables fetching all the entries in cache + // belonging to the leases of a given token. + IndexNameLeaseToken: &memdb.IndexSchema{ + Name: IndexNameLeaseToken, + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "LeaseToken", + }, + }, + // This index enables fetching all the entries in cache + // that are tied to the given token, regardless of the + // entries belonging to the token or belonging to the + // lease. + IndexNameToken: &memdb.IndexSchema{ + Name: IndexNameToken, + Unique: true, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Token", + }, + }, + // This index enables fetching all the entries in cache for + // the given parent token. + IndexNameTokenParent: &memdb.IndexSchema{ + Name: IndexNameTokenParent, + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "TokenParent", + }, + }, + // This index enables fetching all the entries in cache for + // the given accessor. + IndexNameTokenAccessor: &memdb.IndexSchema{ + Name: IndexNameTokenAccessor, + Unique: true, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "TokenAccessor", + }, + }, + // This index enables fetching all the entries in cache for + // the given lease identifier. + IndexNameLease: &memdb.IndexSchema{ + Name: IndexNameLease, + Unique: true, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Lease", + }, + }, + }, + }, + }, + } + + db, err := memdb.NewMemDB(cacheSchema) + if err != nil { + return nil, err + } + return db, nil +} + +// Get returns the index based on the indexer and the index values provided. +func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) { + if !validIndexName(indexName) { + return nil, fmt.Errorf("invalid index name %q", indexName) + } + + raw, err := c.db.Txn(false).First(tableNameIndexer, indexName, indexValues...) + if err != nil { + return nil, err + } + + if raw == nil { + return nil, nil + } + + index, ok := raw.(*Index) + if !ok { + return nil, errors.New("unable to parse index value from the cache") + } + + return index, nil +} + +// Set stores the index into the cache. +func (c *CacheMemDB) Set(index *Index) error { + if index == nil { + return errors.New("nil index provided") + } + + txn := c.db.Txn(true) + defer txn.Abort() + + if err := txn.Insert(tableNameIndexer, index); err != nil { + return fmt.Errorf("unable to insert index into cache: %v", err) + } + + txn.Commit() + + return nil +} + +// GetByPrefix returns all the cached indexes based on the index name and the +// value prefix. +func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) { + if !validIndexName(indexName) { + return nil, fmt.Errorf("invalid index name %q", indexName) + } + + indexName = indexName + "_prefix" + + // Get all the objects + iter, err := c.db.Txn(false).Get(tableNameIndexer, indexName, indexValues...) + if err != nil { + return nil, err + } + + var indexes []*Index + for { + obj := iter.Next() + if obj == nil { + break + } + index, ok := obj.(*Index) + if !ok { + return nil, fmt.Errorf("failed to cast cached index") + } + + indexes = append(indexes, index) + } + + return indexes, nil +} + +// Evict removes an index from the cache based on index name and value. +func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error { + index, err := c.Get(indexName, indexValues...) + if err != nil { + return fmt.Errorf("unable to fetch index on cache deletion: %v", err) + } + + if index == nil { + return nil + } + + txn := c.db.Txn(true) + defer txn.Abort() + + if err := txn.Delete(tableNameIndexer, index); err != nil { + return fmt.Errorf("unable to delete index from cache: %v", err) + } + + txn.Commit() + + return nil +} + +// EvictAll removes all matching indexes from the cache based on index name and value. +func (c *CacheMemDB) EvictAll(indexName, indexValue string) error { + return c.batchEvict(false, indexName, indexValue) +} + +// EvictByPrefix removes all matching prefix indexes from the cache based on index name and prefix. +func (c *CacheMemDB) EvictByPrefix(indexName, indexPrefix string) error { + return c.batchEvict(true, indexName, indexPrefix) +} + +// batchEvict is a helper that supports eviction based on absolute and prefixed index values. +func (c *CacheMemDB) batchEvict(isPrefix bool, indexName string, indexValues ...interface{}) error { + if !validIndexName(indexName) { + return fmt.Errorf("invalid index name %q", indexName) + } + + if isPrefix { + indexName = indexName + "_prefix" + } + + txn := c.db.Txn(true) + defer txn.Abort() + + _, err := txn.DeleteAll(tableNameIndexer, indexName, indexValues...) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +// Flush resets the underlying cache object. +func (c *CacheMemDB) Flush() error { + newDB, err := newDB() + if err != nil { + return err + } + + c.db = newDB + + return nil +} diff --git a/command/agent/cache/cachememdb/cache_memdb_test.go b/command/agent/cache/cachememdb/cache_memdb_test.go new file mode 100644 index 000000000000..a8af42f5356f --- /dev/null +++ b/command/agent/cache/cachememdb/cache_memdb_test.go @@ -0,0 +1,388 @@ +package cachememdb + +import ( + "context" + "testing" + + "github.com/go-test/deep" +) + +func testContextInfo() *ContextInfo { + ctx, cancelFunc := context.WithCancel(context.Background()) + + return &ContextInfo{ + Ctx: ctx, + CancelFunc: cancelFunc, + } +} + +func TestNew(t *testing.T) { + _, err := New() + if err != nil { + t.Fatal(err) + } +} + +func TestCacheMemDB_Get(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Test invalid index name + _, err = cache.Get("foo", "bar") + if err == nil { + t.Fatal("expected error") + } + + // Test on empty cache + index, err := cache.Get(IndexNameID, "foo") + if err != nil { + t.Fatal(err) + } + if index != nil { + t.Fatalf("expected nil index, got: %v", index) + } + + // Populate cache + in := &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Token: "test_token", + TokenAccessor: "test_accessor", + Lease: "test_lease", + Response: []byte("hello world"), + } + + if err := cache.Set(in); err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + indexName string + indexValues []interface{} + }{ + { + "by_index_id", + "id", + []interface{}{in.ID}, + }, + { + "by_request_path", + "request_path", + []interface{}{in.Namespace, in.RequestPath}, + }, + { + "by_lease", + "lease", + []interface{}{in.Lease}, + }, + { + "by_token", + "token", + []interface{}{in.Token}, + }, + { + "by_token_accessor", + "token_accessor", + []interface{}{in.TokenAccessor}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + out, err := cache.Get(tc.indexName, tc.indexValues...) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(in, out); diff != nil { + t.Fatal(diff) + } + }) + } +} + +func TestCacheMemDB_GetByPrefix(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Test invalid index name + _, err = cache.GetByPrefix("foo", "bar", "baz") + if err == nil { + t.Fatal("expected error") + } + + // Test on empty cache + index, err := cache.GetByPrefix(IndexNameRequestPath, "foo", "bar") + if err != nil { + t.Fatal(err) + } + if index != nil { + t.Fatalf("expected nil index, got: %v", index) + } + + // Populate cache + in := &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path/1", + Token: "test_token", + TokenAccessor: "test_accessor", + Lease: "path/to/test_lease/1", + Response: []byte("hello world"), + } + + if err := cache.Set(in); err != nil { + t.Fatal(err) + } + + // Populate cache + in2 := &Index{ + ID: "test_id_2", + Namespace: "test_ns/", + RequestPath: "/v1/request/path/2", + Token: "test_token", + TokenAccessor: "test_accessor", + Lease: "path/to/test_lease/2", + Response: []byte("hello world"), + } + + if err := cache.Set(in2); err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + indexName string + indexValues []interface{} + }{ + { + "by_request_path", + "request_path", + []interface{}{"test_ns/", "/v1/request/path"}, + }, + { + "by_lease", + "lease", + []interface{}{"path/to/test_lease"}, + }, + { + "by_token", + "token", + []interface{}{"test_token"}, + }, + { + "by_token_accessor", + "token_accessor", + []interface{}{"test_accessor"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + out, err := cache.GetByPrefix(tc.indexName, tc.indexValues...) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal([]*Index{in, in2}, out); diff != nil { + t.Fatal(diff) + } + }) + } +} + +func TestCacheMemDB_Set(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + index *Index + wantErr bool + }{ + { + "nil", + nil, + true, + }, + { + "empty_fields", + &Index{}, + true, + }, + { + "missing_required_fields", + &Index{ + Lease: "foo", + }, + true, + }, + { + "all_fields", + &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Token: "test_token", + TokenAccessor: "test_accessor", + Lease: "test_lease", + RenewCtxInfo: testContextInfo(), + }, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := cache.Set(tc.index); (err != nil) != tc.wantErr { + t.Fatalf("CacheMemDB.Set() error = %v, wantErr = %v", err, tc.wantErr) + } + }) + } +} + +func TestCacheMemDB_Evict(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Test on empty cache + if err := cache.Evict(IndexNameID, "foo"); err != nil { + t.Fatal(err) + } + + testIndex := &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Token: "test_token", + TokenAccessor: "test_token_accessor", + Lease: "test_lease", + RenewCtxInfo: testContextInfo(), + } + + testCases := []struct { + name string + indexName string + indexValues []interface{} + insertIndex *Index + wantErr bool + }{ + { + "empty_params", + "", + []interface{}{""}, + nil, + true, + }, + { + "invalid_params", + "foo", + []interface{}{"bar"}, + nil, + true, + }, + { + "by_id", + "id", + []interface{}{"test_id"}, + testIndex, + false, + }, + { + "by_request_path", + "request_path", + []interface{}{"test_ns/", "/v1/request/path"}, + testIndex, + false, + }, + { + "by_token", + "token", + []interface{}{"test_token"}, + testIndex, + false, + }, + { + "by_token_accessor", + "token_accessor", + []interface{}{"test_accessor"}, + testIndex, + false, + }, + { + "by_lease", + "lease", + []interface{}{"test_lease"}, + testIndex, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.insertIndex != nil { + if err := cache.Set(tc.insertIndex); err != nil { + t.Fatal(err) + } + } + + if err := cache.Evict(tc.indexName, tc.indexValues...); (err != nil) != tc.wantErr { + t.Fatal(err) + } + + // Verify that the cache doesn't contain the entry any more + index, err := cache.Get(tc.indexName, tc.indexValues...) + if (err != nil) != tc.wantErr { + t.Fatal(err) + } + + if index != nil { + t.Fatalf("expected nil entry, got = %#v", index) + } + }) + } +} + +func TestCacheMemDB_Flush(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Populate cache + in := &Index{ + ID: "test_id", + Token: "test_token", + Lease: "test_lease", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Response: []byte("hello world"), + } + + if err := cache.Set(in); err != nil { + t.Fatal(err) + } + + // Reset the cache + if err := cache.Flush(); err != nil { + t.Fatal(err) + } + + // Check the cache doesn't contain inserted index + out, err := cache.Get(IndexNameID, "test_id") + if err != nil { + t.Fatal(err) + } + if out != nil { + t.Fatalf("expected cache to be empty, got = %v", out) + } +} diff --git a/command/agent/cache/cachememdb/index.go b/command/agent/cache/cachememdb/index.go new file mode 100644 index 000000000000..4d932ca4f2fd --- /dev/null +++ b/command/agent/cache/cachememdb/index.go @@ -0,0 +1,97 @@ +package cachememdb + +import "context" + +type ContextInfo struct { + Ctx context.Context + CancelFunc context.CancelFunc + DoneCh chan struct{} +} + +// Index holds the response to be cached along with multiple other values that +// serve as pointers to refer back to this index. +type Index struct { + // ID is a value that uniquely represents the request held by this + // index. This is computed by serializing and hashing the response object. + // Required: true, Unique: true + ID string + + // Token is the token that fetched the response held by this index + // Required: true, Unique: true + Token string + + // TokenParent is the parent token of the token held by this index + // Required: false, Unique: false + TokenParent string + + // TokenAccessor is the accessor of the token being cached in this index + // Required: true, Unique: true + TokenAccessor string + + // Namespace is the namespace that was provided in the request path as the + // Vault namespace to query + Namespace string + + // RequestPath is the path of the request that resulted in the response + // held by this index. + // Required: true, Unique: false + RequestPath string + + // Lease is the identifier of the lease in Vault, that belongs to the + // response held by this index. + // Required: false, Unique: true + Lease string + + // LeaseToken is the identifier of the token that created the lease held by + // this index. + // Required: false, Unique: false + LeaseToken string + + // Response is the serialized response object that the agent is caching. + Response []byte + + // RenewCtxInfo holds the context and the corresponding cancel func for the + // goroutine that manages the renewal of the secret belonging to the + // response in this index. + RenewCtxInfo *ContextInfo +} + +type IndexName uint32 + +const ( + // IndexNameID is the ID of the index constructed from the serialized request. + IndexNameID = "id" + + // IndexNameLease is the lease of the index. + IndexNameLease = "lease" + + // IndexNameRequestPath is the request path of the index. + IndexNameRequestPath = "request_path" + + // IndexNameToken is the token of the index. + IndexNameToken = "token" + + // IndexNameTokenAccessor is the token accessor of the index. + IndexNameTokenAccessor = "token_accessor" + + // IndexNameTokenParent is the token parent of the index. + IndexNameTokenParent = "token_parent" + + // IndexNameLeaseToken is the token that created the lease. + IndexNameLeaseToken = "lease_token" +) + +func validIndexName(indexName string) bool { + switch indexName { + case "id": + case "lease": + case "request_path": + case "token": + case "token_accessor": + case "token_parent": + case "lease_token": + default: + return false + } + return true +} diff --git a/command/agent/cache/handler.go b/command/agent/cache/handler.go new file mode 100644 index 000000000000..10c36c7dd22b --- /dev/null +++ b/command/agent/cache/handler.go @@ -0,0 +1,155 @@ +package cache + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/hashicorp/errwrap" + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" +) + +func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, useAutoAuthToken bool, client *api.Client) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Info("received request", "path", r.URL.Path, "method", r.Method) + + token := r.Header.Get(consts.AuthHeaderName) + if token == "" && useAutoAuthToken { + logger.Debug("using auto auth token") + token = client.Token() + } + + // Parse and reset body. + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + logger.Error("failed to read request body") + respondError(w, http.StatusInternalServerError, errors.New("failed to read request body")) + } + if r.Body != nil { + r.Body.Close() + } + r.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) + req := &SendRequest{ + Token: token, + Request: r, + RequestBody: reqBody, + } + + resp, err := proxier.Send(ctx, req) + if err != nil { + respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to get the response: {{err}}", err)) + return + } + + err = processTokenLookupResponse(ctx, logger, useAutoAuthToken, client, req, resp) + if err != nil { + respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to process token lookup response: {{err}}", err)) + return + } + + defer resp.Response.Body.Close() + + copyHeader(w.Header(), resp.Response.Header) + w.WriteHeader(resp.Response.StatusCode) + io.Copy(w, resp.Response.Body) + return + }) +} + +// processTokenLookupResponse checks if the request was one of token +// lookup-self. If the auto-auth token was used to perform lookup-self, the +// identifier of the token and its accessor same will be stripped off of the +// response. +func processTokenLookupResponse(ctx context.Context, logger hclog.Logger, useAutoAuthToken bool, client *api.Client, req *SendRequest, resp *SendResponse) error { + // If auto-auth token is not being used, there is nothing to do. + if !useAutoAuthToken { + return nil + } + + // If lookup responded with non 200 status, there is nothing to do. + if resp.Response.StatusCode != http.StatusOK { + return nil + } + + // Strip-off namespace related information from the request and get the + // relative path of the request. + _, path := deriveNamespaceAndRevocationPath(req) + if path == vaultPathTokenLookupSelf { + logger.Info("stripping auto-auth token from the response", "path", req.Request.URL.Path, "method", req.Request.Method) + secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody)) + if err != nil { + return fmt.Errorf("failed to parse token lookup response: %v", err) + } + if secret != nil && secret.Data != nil && secret.Data["id"] != nil { + token, ok := secret.Data["id"].(string) + if !ok { + return fmt.Errorf("failed to type assert the token id in the response") + } + if token == client.Token() { + delete(secret.Data, "id") + delete(secret.Data, "accessor") + } + + bodyBytes, err := json.Marshal(secret) + if err != nil { + return err + } + if resp.Response.Body != nil { + resp.Response.Body.Close() + } + resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + resp.Response.ContentLength = int64(len(bodyBytes)) + + // Serialize and re-read the reponse + var respBytes bytes.Buffer + err = resp.Response.Write(&respBytes) + if err != nil { + return fmt.Errorf("failed to serialize the updated response: %v", err) + } + + updatedResponse, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBytes.Bytes())), nil) + if err != nil { + return fmt.Errorf("failed to deserialize the updated response: %v", err) + } + + resp.Response = &api.Response{ + Response: updatedResponse, + } + resp.ResponseBody = bodyBytes + } + } + return nil +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func respondError(w http.ResponseWriter, status int, err error) { + logical.AdjustErrorStatusCode(&status, err) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + resp := &vaulthttp.ErrorResponse{Errors: make([]string, 0, 1)} + if err != nil { + resp.Errors = append(resp.Errors, err.Error()) + } + + enc := json.NewEncoder(w) + enc.Encode(resp) +} diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go new file mode 100644 index 000000000000..a998ec96fb51 --- /dev/null +++ b/command/agent/cache/lease_cache.go @@ -0,0 +1,813 @@ +package cache + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + + "github.com/hashicorp/errwrap" + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb" + "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/namespace" + nshelper "github.com/hashicorp/vault/helper/namespace" +) + +const ( + vaultPathTokenCreate = "/v1/auth/token/create" + vaultPathTokenRevoke = "/v1/auth/token/revoke" + vaultPathTokenRevokeSelf = "/v1/auth/token/revoke-self" + vaultPathTokenRevokeAccessor = "/v1/auth/token/revoke-accessor" + vaultPathTokenRevokeOrphan = "/v1/auth/token/revoke-orphan" + vaultPathTokenLookupSelf = "/v1/auth/token/lookup-self" + vaultPathLeaseRevoke = "/v1/sys/leases/revoke" + vaultPathLeaseRevokeForce = "/v1/sys/leases/revoke-force" + vaultPathLeaseRevokePrefix = "/v1/sys/leases/revoke-prefix" +) + +var ( + contextIndexID = contextIndex{} + errInvalidType = errors.New("invalid type provided") + revocationPaths = []string{ + strings.TrimPrefix(vaultPathTokenRevoke, "/v1"), + strings.TrimPrefix(vaultPathTokenRevokeSelf, "/v1"), + strings.TrimPrefix(vaultPathTokenRevokeAccessor, "/v1"), + strings.TrimPrefix(vaultPathTokenRevokeOrphan, "/v1"), + strings.TrimPrefix(vaultPathLeaseRevoke, "/v1"), + strings.TrimPrefix(vaultPathLeaseRevokeForce, "/v1"), + strings.TrimPrefix(vaultPathLeaseRevokePrefix, "/v1"), + } +) + +type contextIndex struct{} + +type cacheClearRequest struct { + Type string `json:"type"` + Value string `json:"value"` + Namespace string `json:"namespace"` +} + +// LeaseCache is an implementation of Proxier that handles +// the caching of responses. It passes the incoming request +// to an underlying Proxier implementation. +type LeaseCache struct { + proxier Proxier + logger hclog.Logger + db *cachememdb.CacheMemDB + baseCtxInfo *ContextInfo +} + +// LeaseCacheConfig is the configuration for initializing a new +// Lease. +type LeaseCacheConfig struct { + BaseContext context.Context + Proxier Proxier + Logger hclog.Logger +} + +// ContextInfo holds a derived context and cancelFunc pair. +type ContextInfo struct { + Ctx context.Context + CancelFunc context.CancelFunc + DoneCh chan struct{} +} + +// NewLeaseCache creates a new instance of a LeaseCache. +func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { + if conf == nil { + return nil, errors.New("nil configuration provided") + } + + if conf.Proxier == nil || conf.Logger == nil { + return nil, fmt.Errorf("missing configuration required params: %v", conf) + } + + db, err := cachememdb.New() + if err != nil { + return nil, err + } + + // Create a base context for the lease cache layer + baseCtx, baseCancelFunc := context.WithCancel(conf.BaseContext) + baseCtxInfo := &ContextInfo{ + Ctx: baseCtx, + CancelFunc: baseCancelFunc, + } + + return &LeaseCache{ + proxier: conf.Proxier, + logger: conf.Logger, + db: db, + baseCtxInfo: baseCtxInfo, + }, nil +} + +// Send performs a cache lookup on the incoming request. If it's a cache hit, +// it will return the cached response, otherwise it will delegate to the +// underlying Proxier and cache the received response. +func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { + // Compute the index ID + id, err := computeIndexID(req) + if err != nil { + c.logger.Error("failed to compute cache key", "error", err) + return nil, err + } + + // Check if the response for this request is already in the cache + index, err := c.db.Get(cachememdb.IndexNameID, id) + if err != nil { + return nil, err + } + + // Cached request is found, deserialize the response and return early + if index != nil { + c.logger.Debug("returning cached response", "path", req.Request.URL.Path) + + reader := bufio.NewReader(bytes.NewReader(index.Response)) + resp, err := http.ReadResponse(reader, nil) + if err != nil { + c.logger.Error("failed to deserialize response", "error", err) + return nil, err + } + + return &SendResponse{ + Response: &api.Response{ + Response: resp, + }, + ResponseBody: index.Response, + }, nil + } + + c.logger.Debug("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method) + + // Pass the request down and get a response + resp, err := c.proxier.Send(ctx, req) + if err != nil { + return nil, err + } + + // Get the namespace from the request header + namespace := req.Request.Header.Get(consts.NamespaceHeaderName) + // We need to populate an empty value since go-memdb will skip over indexes + // that contain empty values. + if namespace == "" { + namespace = "root/" + } + + // Build the index to cache based on the response received + index = &cachememdb.Index{ + ID: id, + Namespace: namespace, + RequestPath: req.Request.URL.Path, + } + + secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody)) + if err != nil { + c.logger.Error("failed to parse response as secret", "error", err) + return nil, err + } + + isRevocation, err := c.handleRevocationRequest(ctx, req, resp) + if err != nil { + c.logger.Error("failed to process the response", "error", err) + return nil, err + } + + // If this is a revocation request, do not go through cache logic. + if isRevocation { + return resp, nil + } + + // Fast path for responses with no secrets + if secret == nil { + c.logger.Debug("pass-through response; no secret in response", "path", req.Request.URL.Path, "method", req.Request.Method) + return resp, nil + } + + // Short-circuit if the secret is not renewable + tokenRenewable, err := secret.TokenIsRenewable() + if err != nil { + c.logger.Error("failed to parse renewable param", "error", err) + return nil, err + } + if !secret.Renewable && !tokenRenewable { + c.logger.Debug("pass-through response; secret not renewable", "path", req.Request.URL.Path, "method", req.Request.Method) + return resp, nil + } + + var renewCtxInfo *ContextInfo + switch { + case secret.LeaseID != "": + c.logger.Debug("processing lease response", "path", req.Request.URL.Path, "method", req.Request.Method) + entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) + if err != nil { + return nil, err + } + // If the lease belongs to a token that is not managed by the agent, + // return the response without caching it. + if entry == nil { + c.logger.Debug("pass-through lease response; token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method) + return resp, nil + } + + // Derive a context for renewal using the token's context + newCtxInfo := new(ContextInfo) + newCtxInfo.Ctx, newCtxInfo.CancelFunc = context.WithCancel(entry.RenewCtxInfo.Ctx) + newCtxInfo.DoneCh = make(chan struct{}) + renewCtxInfo = newCtxInfo + + index.Lease = secret.LeaseID + index.LeaseToken = req.Token + + case secret.Auth != nil: + c.logger.Debug("processing auth response", "path", req.Request.URL.Path, "method", req.Request.Method) + isNonOrphanNewToken := strings.HasPrefix(req.Request.URL.Path, vaultPathTokenCreate) && resp.Response.StatusCode == http.StatusOK && !secret.Auth.Orphan + + // If the new token is a result of token creation endpoints (not from + // login endpoints), and if its a non-orphan, then the new token's + // context should be derived from the context of the parent token. + var parentCtx context.Context + if isNonOrphanNewToken { + entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) + if err != nil { + return nil, err + } + // If parent token is not managed by the agent, child shouldn't be + // either. + if entry == nil { + c.logger.Debug("pass-through auth response; parent token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method) + return resp, nil + } + + c.logger.Debug("setting parent context", "path", req.Request.URL.Path, "method", req.Request.Method) + parentCtx = entry.RenewCtxInfo.Ctx + + entry.TokenParent = req.Token + } + + renewCtxInfo = c.createCtxInfo(parentCtx, secret.Auth.ClientToken) + index.Token = secret.Auth.ClientToken + index.TokenAccessor = secret.Auth.Accessor + + default: + // We shouldn't be hitting this, but will err on the side of caution and + // simply proxy. + c.logger.Debug("pass-through response; secret without lease and token", "path", req.Request.URL.Path, "method", req.Request.Method) + return resp, nil + } + + // Serialize the response to store it in the cached index + var respBytes bytes.Buffer + err = resp.Response.Write(&respBytes) + if err != nil { + c.logger.Error("failed to serialize response", "error", err) + return nil, err + } + + // Reset the response body for upper layers to read + if resp.Response.Body != nil { + resp.Response.Body.Close() + } + resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(resp.ResponseBody)) + + // Set the index's Response + index.Response = respBytes.Bytes() + + // Store the index ID in the renewer context + renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID) + + // Store the renewer context in the index + index.RenewCtxInfo = &cachememdb.ContextInfo{ + Ctx: renewCtx, + CancelFunc: renewCtxInfo.CancelFunc, + DoneCh: renewCtxInfo.DoneCh, + } + + // Store the index in the cache + c.logger.Debug("storing response into the cache", "path", req.Request.URL.Path, "method", req.Request.Method) + err = c.db.Set(index) + if err != nil { + c.logger.Error("failed to cache the proxied response", "error", err) + return nil, err + } + + // Start renewing the secret in the response + go c.startRenewing(renewCtx, index, req, secret) + + return resp, nil +} + +func (c *LeaseCache) createCtxInfo(ctx context.Context, token string) *ContextInfo { + if ctx == nil { + ctx = c.baseCtxInfo.Ctx + } + ctxInfo := new(ContextInfo) + ctxInfo.Ctx, ctxInfo.CancelFunc = context.WithCancel(ctx) + ctxInfo.DoneCh = make(chan struct{}) + return ctxInfo +} + +func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) { + defer func() { + id := ctx.Value(contextIndexID).(string) + c.logger.Debug("evicting index from cache", "id", id, "path", req.Request.URL.Path, "method", req.Request.Method) + err := c.db.Evict(cachememdb.IndexNameID, id) + if err != nil { + c.logger.Error("failed to evict index", "id", id, "error", err) + return + } + }() + + client, err := api.NewClient(api.DefaultConfig()) + if err != nil { + c.logger.Error("failed to create API client in the renewer", "error", err) + return + } + client.SetToken(req.Token) + client.SetHeaders(req.Request.Header) + + renewer, err := client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + c.logger.Error("failed to create secret renewer", "error", err) + return + } + + c.logger.Debug("initiating renewal", "path", req.Request.URL.Path, "method", req.Request.Method) + go renewer.Renew() + defer renewer.Stop() + + for { + select { + case <-ctx.Done(): + // This is the case which captures context cancellations from token + // and leases. Since all the contexts are derived from the agent's + // context, this will also cover the shutdown scenario. + c.logger.Debug("context cancelled; stopping renewer", "path", req.Request.URL.Path) + return + case err := <-renewer.DoneCh(): + // This case covers renewal completion and renewal errors + if err != nil { + c.logger.Error("failed to renew secret", "error", err) + return + } + c.logger.Debug("renewal halted; evicting from cache", "path", req.Request.URL.Path) + return + case renewal := <-renewer.RenewCh(): + // This case captures secret renewals. Renewed secret is updated in + // the cached index. + c.logger.Debug("renewal received; updating cache", "path", req.Request.URL.Path) + err = c.updateResponse(ctx, renewal) + if err != nil { + c.logger.Error("failed to handle renewal", "error", err) + return + } + case <-index.RenewCtxInfo.DoneCh: + // This case indicates the renewal process to shutdown and evict + // the cache entry. This is triggered when a specific secret + // renewal needs to be killed without affecting any of the derived + // context renewals. + c.logger.Debug("done channel closed") + return + } + } +} + +func (c *LeaseCache) updateResponse(ctx context.Context, renewal *api.RenewOutput) error { + id := ctx.Value(contextIndexID).(string) + + // Get the cached index using the id in the context + index, err := c.db.Get(cachememdb.IndexNameID, id) + if err != nil { + return err + } + if index == nil { + return fmt.Errorf("missing cache entry for id: %q", id) + } + + // Read the response from the index + resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(index.Response)), nil) + if err != nil { + c.logger.Error("failed to deserialize response", "error", err) + return err + } + + // Update the body in the reponse by the renewed secret + bodyBytes, err := json.Marshal(renewal.Secret) + if err != nil { + return err + } + if resp.Body != nil { + resp.Body.Close() + } + resp.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + resp.ContentLength = int64(len(bodyBytes)) + + // Serialize the response + var respBytes bytes.Buffer + err = resp.Write(&respBytes) + if err != nil { + c.logger.Error("failed to serialize updated response", "error", err) + return err + } + + // Update the response in the index and set it in the cache + index.Response = respBytes.Bytes() + err = c.db.Set(index) + if err != nil { + c.logger.Error("failed to cache the proxied response", "error", err) + return err + } + + return nil +} + +// computeIndexID results in a value that uniquely identifies a request +// received by the agent. It does so by SHA256 hashing the serialized request +// object containing the request path, query parameters and body parameters. +func computeIndexID(req *SendRequest) (string, error) { + var b bytes.Buffer + + // Serialze the request + if err := req.Request.Write(&b); err != nil { + return "", fmt.Errorf("failed to serialize request: %v", err) + } + + // Reset the request body after it has been closed by Write + if req.Request.Body != nil { + req.Request.Body.Close() + } + req.Request.Body = ioutil.NopCloser(bytes.NewBuffer(req.RequestBody)) + + // Append req.Token into the byte slice. This is needed since auto-auth'ed + // requests sets the token directly into SendRequest.Token + b.Write([]byte(req.Token)) + + sum := sha256.Sum256(b.Bytes()) + return hex.EncodeToString(sum[:]), nil +} + +// HandleCacheClear returns a handlerFunc that can perform cache clearing operations. +func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req := new(cacheClearRequest) + if err := jsonutil.DecodeJSONFromReader(r.Body, req); err != nil { + if err == io.EOF { + err = errors.New("empty JSON provided") + } + respondError(w, http.StatusBadRequest, errwrap.Wrapf("failed to parse JSON input: {{err}}", err)) + return + } + + c.logger.Debug("received cache-clear request", "type", req.Type, "namespace", req.Namespace, "value", req.Value) + + if err := c.handleCacheClear(ctx, req.Type, req.Namespace, req.Value); err != nil { + // Default to 500 on error, unless the user provided an invalid type, + // which would then be a 400. + httpStatus := http.StatusInternalServerError + if err == errInvalidType { + httpStatus = http.StatusBadRequest + } + respondError(w, httpStatus, errwrap.Wrapf("failed to clear cache: {{err}}", err)) + return + } + + return + }) +} + +func (c *LeaseCache) handleCacheClear(ctx context.Context, clearType string, clearValues ...interface{}) error { + if len(clearValues) == 0 { + return errors.New("no value(s) provided to clear corresponding cache entries") + } + + // The value that we want to clear, for most cases, is the last one provided. + clearValue, ok := clearValues[len(clearValues)-1].(string) + if !ok { + return fmt.Errorf("unable to convert %v to type string", clearValue) + } + + switch clearType { + case "request_path": + // For this particular case, we need to ensure that there are 2 provided + // indexers for the proper lookup. + if len(clearValues) != 2 { + return fmt.Errorf("clearing cache by request path requires 2 indexers, got %d", len(clearValues)) + } + + // The first value provided for this case will be the namespace, but if it's + // an empty value we need to overwrite it with "root/" to ensure proper + // cache lookup. + if clearValues[0].(string) == "" { + clearValues[0] = "root/" + } + + // Find all the cached entries which has the given request path and + // cancel the contexts of all the respective renewers + indexes, err := c.db.GetByPrefix(clearType, clearValues...) + if err != nil { + return err + } + for _, index := range indexes { + index.RenewCtxInfo.CancelFunc() + } + + case "token": + if clearValue == "" { + return nil + } + + // Get the context for the given token and cancel its context + index, err := c.db.Get(cachememdb.IndexNameToken, clearValue) + if err != nil { + return err + } + if index == nil { + return nil + } + + c.logger.Debug("cancelling context of index attached to token") + + index.RenewCtxInfo.CancelFunc() + + case "token_accessor", "lease": + // Get the cached index and cancel the corresponding renewer context + index, err := c.db.Get(clearType, clearValue) + if err != nil { + return err + } + if index == nil { + return nil + } + + c.logger.Debug("cancelling context of index attached to accessor") + + index.RenewCtxInfo.CancelFunc() + + case "all": + // Cancel the base context which triggers all the goroutines to + // stop and evict entries from cache. + c.logger.Debug("cancelling base context") + c.baseCtxInfo.CancelFunc() + + // Reset the base context + baseCtx, baseCancel := context.WithCancel(ctx) + c.baseCtxInfo = &ContextInfo{ + Ctx: baseCtx, + CancelFunc: baseCancel, + } + + // Reset the memdb instance + if err := c.db.Flush(); err != nil { + return err + } + + default: + return errInvalidType + } + + c.logger.Debug("successfully cleared matching cache entries") + + return nil +} + +// handleRevocationRequest checks whether the originating request is a +// revocation request, and if so perform applicable cache cleanups. +// Returns true is this is a revocation request. +func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendRequest, resp *SendResponse) (bool, error) { + // Lease and token revocations return 204's on success. Fast-path if that's + // not the case. + if resp.Response.StatusCode != http.StatusNoContent { + return false, nil + } + + _, path := deriveNamespaceAndRevocationPath(req) + + switch { + case path == vaultPathTokenRevoke: + // Get the token from the request body + jsonBody := map[string]interface{}{} + if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil { + return false, err + } + tokenRaw, ok := jsonBody["token"] + if !ok { + return false, fmt.Errorf("failed to get token from request body") + } + token, ok := tokenRaw.(string) + if !ok { + return false, fmt.Errorf("expected token in the request body to be string") + } + + // Clear the cache entry associated with the token and all the other + // entries belonging to the leases derived from this token. + if err := c.handleCacheClear(ctx, "token", token); err != nil { + return false, err + } + + case path == vaultPathTokenRevokeSelf: + // Clear the cache entry associated with the token and all the other + // entries belonging to the leases derived from this token. + if err := c.handleCacheClear(ctx, "token", req.Token); err != nil { + return false, err + } + + case path == vaultPathTokenRevokeAccessor: + jsonBody := map[string]interface{}{} + if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil { + return false, err + } + accessorRaw, ok := jsonBody["accessor"] + if !ok { + return false, fmt.Errorf("failed to get accessor from request body") + } + accessor, ok := accessorRaw.(string) + if !ok { + return false, fmt.Errorf("expected accessor in the request body to be string") + } + + if err := c.handleCacheClear(ctx, "token_accessor", accessor); err != nil { + return false, err + } + + case path == vaultPathTokenRevokeOrphan: + jsonBody := map[string]interface{}{} + if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil { + return false, err + } + tokenRaw, ok := jsonBody["token"] + if !ok { + return false, fmt.Errorf("failed to get token from request body") + } + token, ok := tokenRaw.(string) + if !ok { + return false, fmt.Errorf("expected token in the request body to be string") + } + + // Kill the renewers of all the leases attached to the revoked token + indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLeaseToken, token) + if err != nil { + return false, err + } + for _, index := range indexes { + index.RenewCtxInfo.CancelFunc() + } + + // Kill the renewer of the revoked token + index, err := c.db.Get(cachememdb.IndexNameToken, token) + if err != nil { + return false, err + } + if index == nil { + return true, nil + } + + // Indicate the renewer goroutine for this index to return. This will + // not affect the child tokens because the context is not getting + // cancelled. + close(index.RenewCtxInfo.DoneCh) + + // Clear the parent references of the revoked token in the entries + // belonging to the child tokens of the revoked token. + indexes, err = c.db.GetByPrefix(cachememdb.IndexNameTokenParent, token) + if err != nil { + return false, err + } + for _, index := range indexes { + index.TokenParent = "" + err = c.db.Set(index) + if err != nil { + c.logger.Error("failed to persist index", "error", err) + return false, err + } + } + + case path == vaultPathLeaseRevoke: + // TODO: Should lease present in the URL itself be considered here? + // Get the lease from the request body + jsonBody := map[string]interface{}{} + if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil { + return false, err + } + leaseIDRaw, ok := jsonBody["lease_id"] + if !ok { + return false, fmt.Errorf("failed to get lease_id from request body") + } + leaseID, ok := leaseIDRaw.(string) + if !ok { + return false, fmt.Errorf("expected lease_id the request body to be string") + } + if err := c.handleCacheClear(ctx, "lease", leaseID); err != nil { + return false, err + } + + case strings.HasPrefix(path, vaultPathLeaseRevokeForce): + // Trim the URL path to get the request path prefix + prefix := strings.TrimPrefix(path, vaultPathLeaseRevokeForce) + // Get all the cache indexes that use the request path containing the + // prefix and cancel the renewer context of each. + indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix) + if err != nil { + return false, err + } + + _, tokenNSID := namespace.SplitIDFromString(req.Token) + for _, index := range indexes { + _, leaseNSID := namespace.SplitIDFromString(index.Lease) + // Only evict leases that match the token's namespace + if tokenNSID == leaseNSID { + index.RenewCtxInfo.CancelFunc() + } + } + + case strings.HasPrefix(path, vaultPathLeaseRevokePrefix): + // Trim the URL path to get the request path prefix + prefix := strings.TrimPrefix(path, vaultPathLeaseRevokePrefix) + // Get all the cache indexes that use the request path containing the + // prefix and cancel the renewer context of each. + indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix) + if err != nil { + return false, err + } + + _, tokenNSID := namespace.SplitIDFromString(req.Token) + for _, index := range indexes { + _, leaseNSID := namespace.SplitIDFromString(index.Lease) + // Only evict leases that match the token's namespace + if tokenNSID == leaseNSID { + index.RenewCtxInfo.CancelFunc() + } + } + + default: + return false, nil + } + + c.logger.Debug("triggered caching eviction from revocation request") + + return true, nil +} + +// deriveNamespaceAndRevocationPath returns the namespace and relative path for +// revocation paths. +// +// If the path contains a namespace, but it's not a revocation path, it will be +// returned as-is, since there's no way to tell where the namespace ends and +// where the request path begins purely based off a string. +// +// Case 1: /v1/ns1/leases/revoke -> ns1/, /v1/leases/revoke +// Case 2: ns1/ /v1/leases/revoke -> ns1/, /v1/leases/revoke +// Case 3: /v1/ns1/foo/bar -> root/, /v1/ns1/foo/bar +// Case 4: ns1/ /v1/foo/bar -> ns1/, /v1/foo/bar +func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) { + namespace := "root/" + nsHeader := req.Request.Header.Get(consts.NamespaceHeaderName) + if nsHeader != "" { + namespace = nsHeader + } + + fullPath := req.Request.URL.Path + nonVersionedPath := strings.TrimPrefix(fullPath, "/v1") + + for _, pathToCheck := range revocationPaths { + // We use strings.Contains here for paths that can contain + // vars in the path, e.g. /v1/lease/revoke-prefix/:prefix + i := strings.Index(nonVersionedPath, pathToCheck) + // If there's no match, move on to the next check + if i == -1 { + continue + } + + // If the index is 0, this is a relative path with no namespace preppended, + // so we can break early + if i == 0 { + break + } + + // We need to turn /ns1 into ns1/, this makes it easy + namespaceInPath := nshelper.Canonicalize(nonVersionedPath[:i]) + + // If it's root, we replace, otherwise we join + if namespace == "root/" { + namespace = namespaceInPath + } else { + namespace = namespace + namespaceInPath + } + + return namespace, fmt.Sprintf("/v1%s", nonVersionedPath[i:]) + } + + return namespace, fmt.Sprintf("/v1%s", nonVersionedPath) +} diff --git a/command/agent/cache/lease_cache_test.go b/command/agent/cache/lease_cache_test.go new file mode 100644 index 000000000000..a455944da738 --- /dev/null +++ b/command/agent/cache/lease_cache_test.go @@ -0,0 +1,507 @@ +package cache + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + + "github.com/go-test/deep" + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/logging" +) + +func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache { + t.Helper() + + lc, err := NewLeaseCache(&LeaseCacheConfig{ + BaseContext: context.Background(), + Proxier: newMockProxier(responses), + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + }) + + if err != nil { + t.Fatal(err) + } + return lc +} + +func TestCache_ComputeIndexID(t *testing.T) { + type args struct { + req *http.Request + } + tests := []struct { + name string + req *SendRequest + want string + wantErr bool + }{ + { + "basic", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "test", + }, + }, + }, + "2edc7e965c3e1bdce3b1d5f79a52927842569c0734a86544d222753f11ae4847", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := computeIndexID(tt.req) + if (err != nil) != tt.wantErr { + t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, string(tt.want)) { + t.Errorf("bad: index id; actual: %q, expected: %q", got, string(tt.want)) + } + }) + } +} + +func TestCache_LeaseCache_EmptyToken(t *testing.T) { + responses := []*SendResponse{ + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusCreated, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)), + }, + }, + ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`), + }, + } + lc := testNewLeaseCache(t, responses) + + // Even if the send request doesn't have a token on it, a successful + // cacheable response should result in the index properly getting populated + // with a token and memdb shouldn't complain while inserting the index. + urlPath := "http://example.com/v1/sample/api" + sendReq := &SendRequest{ + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatalf("expected a non empty response") + } +} + +func TestCache_LeaseCache_SendCacheable(t *testing.T) { + // Emulate 2 responses from the api proxy. One returns a new token and the + // other returns a lease. + responses := []*SendResponse{ + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusCreated, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`)), + }, + }, + ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`), + }, + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo", "renewable": true}`)), + }, + }, + ResponseBody: []byte(`{"value": "output", "lease_id": "foo", "renewable": true}`), + }, + } + lc := testNewLeaseCache(t, responses) + + // Make a request. A response with a new token is returned to the lease + // cache and that will be cached. + urlPath := "http://example.com/v1/sample/api" + sendReq := &SendRequest{ + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Send the same request again to get the cached response + sendReq = &SendRequest{ + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Modify the request a little bit to ensure the second response is + // returned to the lease cache. But make sure that the token in the request + // is valid. + sendReq = &SendRequest{ + Token: "testtoken", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Make the same request again and ensure that the same reponse is returned + // again. + sendReq = &SendRequest{ + Token: "testtoken", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } +} + +func TestCache_LeaseCache_SendNonCacheable(t *testing.T) { + responses := []*SendResponse{ + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "output"}`)), + }, + }, + }, + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid"}`)), + }, + }, + }, + } + lc := testNewLeaseCache(t, responses) + + // Send a request through the lease cache which is not cacheable (there is + // no lease information or auth information in the response) + sendReq := &SendRequest{ + Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Since the response is non-cacheable, the second response will be + // returned. + sendReq = &SendRequest{ + Token: "foo", + Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } +} + +func TestCache_LeaseCache_SendNonCacheableNonTokenLease(t *testing.T) { + // Create the cache + responses := []*SendResponse{ + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo"}`)), + }, + }, + ResponseBody: []byte(`{"value": "output", "lease_id": "foo"}`), + }, + &SendResponse{ + Response: &api.Response{ + Response: &http.Response{ + StatusCode: http.StatusCreated, + Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)), + }, + }, + ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`), + }, + } + lc := testNewLeaseCache(t, responses) + + // Send a request through lease cache which returns a response containing + // lease_id. Response will not be cached because it doesn't belong to a + // token that is managed by the lease cache. + urlPath := "http://example.com/v1/sample/api" + sendReq := &SendRequest{ + Token: "foo", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Verify that the response is not cached by sending the same request and + // by expecting a different response. + sendReq = &SendRequest{ + Token: "foo", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff == nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } +} + +func TestCache_LeaseCache_HandleCacheClear(t *testing.T) { + lc := testNewLeaseCache(t, nil) + + handler := lc.HandleCacheClear(context.Background()) + ts := httptest.NewServer(handler) + defer ts.Close() + + // Test missing body, should return 400 + resp, err := http.Post(ts.URL, "application/json", nil) + if err != nil { + t.Fatal() + } + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode) + } + + testCases := []struct { + name string + reqType string + reqValue string + expectedStatusCode int + }{ + { + "invalid_type", + "foo", + "", + http.StatusBadRequest, + }, + { + "invalid_value", + "", + "bar", + http.StatusBadRequest, + }, + { + "all", + "all", + "", + http.StatusOK, + }, + { + "by_request_path", + "request_path", + "foo", + http.StatusOK, + }, + { + "by_token", + "token", + "foo", + http.StatusOK, + }, + { + "by_lease", + "lease", + "foo", + http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue) + resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody)) + if err != nil { + t.Fatal(err) + } + if tc.expectedStatusCode != resp.StatusCode { + t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode) + } + }) + } +} + +func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) { + tests := []struct { + name string + req *SendRequest + wantNamespace string + wantRelativePath string + }{ + { + "non_revocation_full_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns1/sys/mounts", + }, + }, + }, + "root/", + "/v1/ns1/sys/mounts", + }, + { + "non_revocation_relative_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/sys/mounts", + }, + Header: http.Header{ + consts.NamespaceHeaderName: []string{"ns1/"}, + }, + }, + }, + "ns1/", + "/v1/sys/mounts", + }, + { + "non_revocation_relative_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns2/sys/mounts", + }, + Header: http.Header{ + consts.NamespaceHeaderName: []string{"ns1/"}, + }, + }, + }, + "ns1/", + "/v1/ns2/sys/mounts", + }, + { + "revocation_full_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns1/sys/leases/revoke", + }, + }, + }, + "ns1/", + "/v1/sys/leases/revoke", + }, + { + "revocation_relative_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/sys/leases/revoke", + }, + Header: http.Header{ + consts.NamespaceHeaderName: []string{"ns1/"}, + }, + }, + }, + "ns1/", + "/v1/sys/leases/revoke", + }, + { + "revocation_relative_partial_ns", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns2/sys/leases/revoke", + }, + Header: http.Header{ + consts.NamespaceHeaderName: []string{"ns1/"}, + }, + }, + }, + "ns1/ns2/", + "/v1/sys/leases/revoke", + }, + { + "revocation_prefix_full_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns1/sys/leases/revoke-prefix/foo", + }, + }, + }, + "ns1/", + "/v1/sys/leases/revoke-prefix/foo", + }, + { + "revocation_prefix_relative_path", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/sys/leases/revoke-prefix/foo", + }, + Header: http.Header{ + consts.NamespaceHeaderName: []string{"ns1/"}, + }, + }, + }, + "ns1/", + "/v1/sys/leases/revoke-prefix/foo", + }, + { + "revocation_prefix_partial_ns", + &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns2/sys/leases/revoke-prefix/foo", + }, + Header: http.Header{ + consts.NamespaceHeaderName: []string{"ns1/"}, + }, + }, + }, + "ns1/ns2/", + "/v1/sys/leases/revoke-prefix/foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req) + if gotNamespace != tt.wantNamespace { + t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace) + } + if gotRelativePath != tt.wantRelativePath { + t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath) + } + }) + } +} diff --git a/command/agent/cache/listener.go b/command/agent/cache/listener.go new file mode 100644 index 000000000000..1adca7a8dc4b --- /dev/null +++ b/command/agent/cache/listener.go @@ -0,0 +1,105 @@ +package cache + +import ( + "fmt" + "io" + "net" + "os" + "strings" + + "github.com/hashicorp/vault/command/agent/config" + "github.com/hashicorp/vault/command/server" + "github.com/hashicorp/vault/helper/reload" + "github.com/mitchellh/cli" +) + +func ServerListener(lnConfig *config.Listener, logger io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) { + switch lnConfig.Type { + case "unix": + return unixSocketListener(lnConfig.Config, logger, ui) + case "tcp": + return tcpListener(lnConfig.Config, logger, ui) + default: + return nil, nil, nil, fmt.Errorf("unsupported listener type: %q", lnConfig.Type) + } +} + +func unixSocketListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) { + addr, ok := config["address"].(string) + if !ok { + return nil, nil, nil, fmt.Errorf("invalid address: %v", config["address"]) + } + + if addr == "" { + return nil, nil, nil, fmt.Errorf("address field should point to socket file path") + } + + // Remove the socket file as it shouldn't exist for the domain socket to + // work + err := os.Remove(addr) + if err != nil && !os.IsNotExist(err) { + return nil, nil, nil, fmt.Errorf("failed to remove the socket file: %v", err) + } + + listener, err := net.Listen("unix", addr) + if err != nil { + return nil, nil, nil, err + } + + // Wrap the listener in rmListener so that the Unix domain socket file is + // removed on close. + listener = &rmListener{ + Listener: listener, + Path: addr, + } + + props := map[string]string{"addr": addr, "tls": "disabled"} + + return listener, props, nil, nil +} + +func tcpListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) { + bindProto := "tcp" + var addr string + addrRaw, ok := config["address"] + if !ok { + addr = "127.0.0.1:8300" + } else { + addr = addrRaw.(string) + } + + // If they've passed 0.0.0.0, we only want to bind on IPv4 + // rather than golang's dual stack default + if strings.HasPrefix(addr, "0.0.0.0:") { + bindProto = "tcp4" + } + + ln, err := net.Listen(bindProto, addr) + if err != nil { + return nil, nil, nil, err + } + + ln = server.TCPKeepAliveListener{ln.(*net.TCPListener)} + + props := map[string]string{"addr": addr} + + return server.ListenerWrapTLS(ln, props, config, ui) +} + +// rmListener is an implementation of net.Listener that forwards most +// calls to the listener but also removes a file as part of the close. We +// use this to cleanup the unix domain socket on close. +type rmListener struct { + net.Listener + Path string +} + +func (l *rmListener) Close() error { + // Close the listener itself + if err := l.Listener.Close(); err != nil { + return err + } + + // Remove the file + return os.Remove(l.Path) +} diff --git a/command/agent/cache/proxy.go b/command/agent/cache/proxy.go new file mode 100644 index 000000000000..4637590917e9 --- /dev/null +++ b/command/agent/cache/proxy.go @@ -0,0 +1,28 @@ +package cache + +import ( + "context" + "net/http" + + "github.com/hashicorp/vault/api" +) + +// SendRequest is the input for Proxier.Send. +type SendRequest struct { + Token string + Request *http.Request + RequestBody []byte +} + +// SendResponse is the output from Proxier.Send. +type SendResponse struct { + Response *api.Response + ResponseBody []byte +} + +// Proxier is the interface implemented by different components that are +// responsible for performing specific tasks, such as caching and proxying. All +// these tasks combined together would serve the request received by the agent. +type Proxier interface { + Send(ctx context.Context, req *SendRequest) (*SendResponse, error) +} diff --git a/command/agent/cache/testing.go b/command/agent/cache/testing.go new file mode 100644 index 000000000000..d9de1caadc7d --- /dev/null +++ b/command/agent/cache/testing.go @@ -0,0 +1,36 @@ +package cache + +import ( + "context" + "fmt" +) + +// mockProxier is a mock implementation of the Proxier interface, used for testing purposes. +// The mock will return the provided responses every time it reaches its Send method, up to +// the last provided response. This lets tests control what the next/underlying Proxier layer +// might expect to return. +type mockProxier struct { + proxiedResponses []*SendResponse + responseIndex int +} + +func newMockProxier(responses []*SendResponse) *mockProxier { + return &mockProxier{ + proxiedResponses: responses, + } +} + +func (p *mockProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { + if p.responseIndex >= len(p.proxiedResponses) { + return nil, fmt.Errorf("index out of bounds: responseIndex = %d, responses = %d", p.responseIndex, len(p.proxiedResponses)) + } + resp := p.proxiedResponses[p.responseIndex] + + p.responseIndex++ + + return resp, nil +} + +func (p *mockProxier) ResponseIndex() int { + return p.responseIndex +} diff --git a/command/agent/cache_end_to_end_test.go b/command/agent/cache_end_to_end_test.go new file mode 100644 index 000000000000..88f1c36409f9 --- /dev/null +++ b/command/agent/cache_end_to_end_test.go @@ -0,0 +1,280 @@ +package agent + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "testing" + "time" + + hclog "github.com/hashicorp/go-hclog" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + credAppRole "github.com/hashicorp/vault/builtin/credential/approle" + "github.com/hashicorp/vault/command/agent/auth" + agentapprole "github.com/hashicorp/vault/command/agent/auth/approle" + "github.com/hashicorp/vault/command/agent/cache" + "github.com/hashicorp/vault/command/agent/sink" + "github.com/hashicorp/vault/command/agent/sink/file" + "github.com/hashicorp/vault/helper/logging" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" +) + +func TestCache_UsingAutoAuthToken(t *testing.T) { + var err error + logger := logging.NewVaultLogger(log.Trace) + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: log.NewNullLogger(), + CredentialBackends: map[string]logical.Factory{ + "approle": credAppRole.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + vault.TestWaitActive(t, cores[0].Core) + + client := cores[0].Client + + defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) + os.Setenv(api.EnvVaultAddress, client.Address()) + + defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert)) + os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir)) + + err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{ + Type: "approle", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/approle/role/test1", map[string]interface{}{ + "bind_secret_id": "true", + "token_ttl": "3s", + "token_max_ttl": "10s", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := client.Logical().Write("auth/approle/role/test1/secret-id", nil) + if err != nil { + t.Fatal(err) + } + secretID1 := resp.Data["secret_id"].(string) + + resp, err = client.Logical().Read("auth/approle/role/test1/role-id") + if err != nil { + t.Fatal(err) + } + roleID1 := resp.Data["role_id"].(string) + + rolef, err := ioutil.TempFile("", "auth.role-id.test.") + if err != nil { + t.Fatal(err) + } + role := rolef.Name() + rolef.Close() // WriteFile doesn't need it open + defer os.Remove(role) + t.Logf("input role_id_file_path: %s", role) + + secretf, err := ioutil.TempFile("", "auth.secret-id.test.") + if err != nil { + t.Fatal(err) + } + secret := secretf.Name() + secretf.Close() + defer os.Remove(secret) + t.Logf("input secret_id_file_path: %s", secret) + + // We close these right away because we're just basically testing + // permissions and finding a usable file name + ouf, err := ioutil.TempFile("", "auth.tokensink.test.") + if err != nil { + t.Fatal(err) + } + out := ouf.Name() + ouf.Close() + os.Remove(out) + t.Logf("output: %s", out) + + ctx, cancelFunc := context.WithCancel(context.Background()) + timer := time.AfterFunc(30*time.Second, func() { + cancelFunc() + }) + defer timer.Stop() + + conf := map[string]interface{}{ + "role_id_file_path": role, + "secret_id_file_path": secret, + "remove_secret_id_file_after_reading": true, + } + + am, err := agentapprole.NewApproleAuthMethod(&auth.AuthConfig{ + Logger: logger.Named("auth.approle"), + MountPath: "auth/approle", + Config: conf, + }) + if err != nil { + t.Fatal(err) + } + ahConfig := &auth.AuthHandlerConfig{ + Logger: logger.Named("auth.handler"), + Client: client, + } + ah := auth.NewAuthHandler(ahConfig) + go ah.Run(ctx, am) + defer func() { + <-ah.DoneCh + }() + + config := &sink.SinkConfig{ + Logger: logger.Named("sink.file"), + Config: map[string]interface{}{ + "path": out, + }, + } + fs, err := file.NewFileSink(config) + if err != nil { + t.Fatal(err) + } + config.Sink = fs + + ss := sink.NewSinkServer(&sink.SinkServerConfig{ + Logger: logger.Named("sink.server"), + Client: client, + }) + go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config}) + defer func() { + <-ss.DoneCh + }() + + // This has to be after the other defers so it happens first + defer cancelFunc() + + // Check that no sink file exists + _, err = os.Lstat(out) + if err == nil { + t.Fatal("expected err") + } + if !os.IsNotExist(err) { + t.Fatal("expected notexist err") + } + + if err := ioutil.WriteFile(role, []byte(roleID1), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test role 1", "path", role) + } + + if err := ioutil.WriteFile(secret, []byte(secretID1), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test secret 1", "path", secret) + } + + getToken := func() string { + timeout := time.Now().Add(10 * time.Second) + for { + if time.Now().After(timeout) { + t.Fatal("did not find a written token after timeout") + } + val, err := ioutil.ReadFile(out) + if err == nil { + os.Remove(out) + if len(val) == 0 { + t.Fatal("written token was empty") + } + + _, err = os.Stat(secret) + if err == nil { + t.Fatal("secret file exists but was supposed to be removed") + } + + client.SetToken(string(val)) + _, err := client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + return string(val) + } + time.Sleep(250 * time.Millisecond) + } + } + + t.Logf("auto-auth token: %q", getToken()) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + defer listener.Close() + + cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache") + + // Create the API proxier + apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{ + Logger: cacheLogger.Named("apiproxy"), + }) + + // Create the lease cache proxier and set its underlying proxier to + // the API proxier. + leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{ + BaseContext: ctx, + Proxier: apiProxy, + Logger: cacheLogger.Named("leasecache"), + }) + if err != nil { + t.Fatal(err) + } + + // Create a muxer and add paths relevant for the lease cache layer + mux := http.NewServeMux() + mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx)) + + mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, true, client)) + server := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, + ErrorLog: cacheLogger.StandardLogger(nil), + } + go server.Serve(listener) + + testClient, err := api.NewClient(api.DefaultConfig()) + if err != nil { + t.Fatal(err) + } + + if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil { + t.Fatal(err) + } + + // Wait for listeners to come up + time.Sleep(2 * time.Second) + + resp, err = testClient.Logical().Read("auth/token/lookup-self") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatalf("failed to use the auto-auth token to perform lookup-self") + } +} diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 3a18b946efac..9c9a80aaf9b7 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -22,6 +22,17 @@ type Config struct { AutoAuth *AutoAuth `hcl:"auto_auth"` ExitAfterAuth bool `hcl:"exit_after_auth"` PidFile string `hcl:"pid_file"` + Cache *Cache `hcl:"cache"` +} + +type Cache struct { + UseAutoAuthToken bool `hcl:"use_auto_auth_token"` + Listeners []*Listener `hcl:"listeners"` +} + +type Listener struct { + Type string + Config map[string]interface{} } type AutoAuth struct { @@ -91,9 +102,102 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) { return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err) } + err = parseCache(&result, list) + if err != nil { + return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err) + } + return &result, nil } +func parseCache(result *Config, list *ast.ObjectList) error { + name := "cache" + + cacheList := list.Filter(name) + if len(cacheList.Items) == 0 { + return nil + } + + if len(cacheList.Items) > 1 { + return fmt.Errorf("one and only one %q block is required", name) + } + + item := cacheList.Items[0] + + var c Cache + err := hcl.DecodeObject(&c, item.Val) + if err != nil { + return err + } + + result.Cache = &c + + subs, ok := item.Val.(*ast.ObjectType) + if !ok { + return fmt.Errorf("could not parse %q as an object", name) + } + subList := subs.List + + err = parseListeners(result, subList) + if err != nil { + return errwrap.Wrapf("error parsing 'listener' stanzas: {{err}}", err) + } + + return nil +} + +func parseListeners(result *Config, list *ast.ObjectList) error { + name := "listener" + + listenerList := list.Filter(name) + if len(listenerList.Items) < 1 { + return fmt.Errorf("at least one %q block is required", name) + } + + var listeners []*Listener + for _, item := range listenerList.Items { + var lnConfig map[string]interface{} + err := hcl.DecodeObject(&lnConfig, item.Val) + if err != nil { + return err + } + + var lnType string + switch { + case lnConfig["type"] != nil: + lnType = lnConfig["type"].(string) + delete(lnConfig, "type") + case len(item.Keys) == 1: + lnType = strings.ToLower(item.Keys[0].Token.Value().(string)) + default: + return errors.New("listener type must be specified") + } + + switch lnType { + case "unix": + // Don't accept TLS connection information for unix domain socket + // listener. Maybe something to support in future. + unixLnConfig := map[string]interface{}{ + "tls_disable": true, + } + unixLnConfig["address"] = lnConfig["address"] + lnConfig = unixLnConfig + case "tcp": + default: + return fmt.Errorf("invalid listener type %q", lnType) + } + + listeners = append(listeners, &Listener{ + Type: lnType, + Config: lnConfig, + }) + } + + result.Cache.Listeners = listeners + + return nil +} + func parseAutoAuth(result *Config, list *ast.ObjectList) error { name := "auto_auth" diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 2f78b4fb04fa..49621b50c153 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -10,6 +10,80 @@ import ( "github.com/hashicorp/vault/helper/logging" ) +func TestLoadConfigFile_AgentCache(t *testing.T) { + logger := logging.NewVaultLogger(log.Debug) + + config, err := LoadConfig("./test-fixtures/config-cache.hcl", logger) + if err != nil { + t.Fatal(err) + } + + expected := &Config{ + AutoAuth: &AutoAuth{ + Method: &Method{ + Type: "aws", + WrapTTL: 300 * time.Second, + MountPath: "auth/aws", + Config: map[string]interface{}{ + "role": "foobar", + }, + }, + Sinks: []*Sink{ + &Sink{ + Type: "file", + DHType: "curve25519", + DHPath: "/tmp/file-foo-dhpath", + AAD: "foobar", + Config: map[string]interface{}{ + "path": "/tmp/file-foo", + }, + }, + }, + }, + Cache: &Cache{ + UseAutoAuthToken: true, + Listeners: []*Listener{ + &Listener{ + Type: "unix", + Config: map[string]interface{}{ + "address": "/path/to/socket", + "tls_disable": true, + }, + }, + &Listener{ + Type: "tcp", + Config: map[string]interface{}{ + "address": "127.0.0.1:8300", + "tls_disable": true, + }, + }, + &Listener{ + Type: "tcp", + Config: map[string]interface{}{ + "address": "127.0.0.1:8400", + "tls_key_file": "/path/to/cakey.pem", + "tls_cert_file": "/path/to/cacert.pem", + }, + }, + }, + }, + PidFile: "./pidfile", + } + + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } + + config, err = LoadConfig("./test-fixtures/config-cache-embedded-type.hcl", logger) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } +} + func TestLoadConfigFile(t *testing.T) { logger := logging.NewVaultLogger(log.Debug) diff --git a/command/agent/config/test-fixtures/config-cache-embedded-type.hcl b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl new file mode 100644 index 000000000000..3079b29d7cdb --- /dev/null +++ b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl @@ -0,0 +1,44 @@ +pid_file = "./pidfile" + +auto_auth { + method { + type = "aws" + wrap_ttl = 300 + config = { + role = "foobar" + } + } + + sink { + type = "file" + config = { + path = "/tmp/file-foo" + } + aad = "foobar" + dh_type = "curve25519" + dh_path = "/tmp/file-foo-dhpath" + } +} + +cache { + use_auto_auth_token = true + + listener { + type = "unix" + address = "/path/to/socket" + tls_disable = true + } + + listener { + type = "tcp" + address = "127.0.0.1:8300" + tls_disable = true + } + + listener { + type = "tcp" + address = "127.0.0.1:8400" + tls_key_file = "/path/to/cakey.pem" + tls_cert_file = "/path/to/cacert.pem" + } +} diff --git a/command/agent/config/test-fixtures/config-cache.hcl b/command/agent/config/test-fixtures/config-cache.hcl new file mode 100644 index 000000000000..f2ae5cb380c3 --- /dev/null +++ b/command/agent/config/test-fixtures/config-cache.hcl @@ -0,0 +1,41 @@ +pid_file = "./pidfile" + +auto_auth { + method { + type = "aws" + wrap_ttl = 300 + config = { + role = "foobar" + } + } + + sink { + type = "file" + config = { + path = "/tmp/file-foo" + } + aad = "foobar" + dh_type = "curve25519" + dh_path = "/tmp/file-foo-dhpath" + } +} + +cache { + use_auto_auth_token = true + + listener "unix" { + address = "/path/to/socket" + tls_disable = true + } + + listener "tcp" { + address = "127.0.0.1:8300" + tls_disable = true + } + + listener "tcp" { + address = "127.0.0.1:8400" + tls_key_file = "/path/to/cakey.pem" + tls_cert_file = "/path/to/cacert.pem" + } +} diff --git a/command/agent_test.go b/command/agent_test.go index 386ad47799b7..f08a13f58dd5 100644 --- a/command/agent_test.go +++ b/command/agent_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "testing" + "time" hclog "github.com/hashicorp/go-hclog" vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt" @@ -30,6 +31,188 @@ func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCo } } +func TestAgent_Cache_UnixListener(t *testing.T) { + logger := logging.NewVaultLogger(hclog.Trace) + coreConfig := &vault.CoreConfig{ + Logger: logger.Named("core"), + CredentialBackends: map[string]logical.Factory{ + "jwt": vaultjwt.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + vault.TestWaitActive(t, cluster.Cores[0].Core) + client := cluster.Cores[0].Client + + defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) + os.Setenv(api.EnvVaultAddress, client.Address()) + + defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert)) + os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir)) + + // Setup Vault + err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{ + Type: "jwt", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{ + "bound_issuer": "https://team-vault.auth0.com/", + "jwt_validation_pubkeys": agent.TestECDSAPubKey, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{ + "bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", + "bound_audiences": "https://vault.plugin.auth.jwt.test", + "user_claim": "https://vault/user", + "groups_claim": "https://vault/groups", + "policies": "test", + "period": "3s", + }) + if err != nil { + t.Fatal(err) + } + + inf, err := ioutil.TempFile("", "auth.jwt.test.") + if err != nil { + t.Fatal(err) + } + in := inf.Name() + inf.Close() + os.Remove(in) + t.Logf("input: %s", in) + + sink1f, err := ioutil.TempFile("", "sink1.jwt.test.") + if err != nil { + t.Fatal(err) + } + sink1 := sink1f.Name() + sink1f.Close() + os.Remove(sink1) + t.Logf("sink1: %s", sink1) + + sink2f, err := ioutil.TempFile("", "sink2.jwt.test.") + if err != nil { + t.Fatal(err) + } + sink2 := sink2f.Name() + sink2f.Close() + os.Remove(sink2) + t.Logf("sink2: %s", sink2) + + conff, err := ioutil.TempFile("", "conf.jwt.test.") + if err != nil { + t.Fatal(err) + } + conf := conff.Name() + conff.Close() + os.Remove(conf) + t.Logf("config: %s", conf) + + jwtToken, _ := agent.GetTestJWT(t) + if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test jwt", "path", in) + } + + socketff, err := ioutil.TempFile("", "cache.socket.") + if err != nil { + t.Fatal(err) + } + socketf := socketff.Name() + socketff.Close() + os.Remove(socketf) + t.Logf("socketf: %s", socketf) + + config := ` +auto_auth { + method { + type = "jwt" + config = { + role = "test" + path = "%s" + } + } + + sink { + type = "file" + config = { + path = "%s" + } + } + + sink "file" { + config = { + path = "%s" + } + } +} + +cache { + use_auto_auth_token = true + + listener "unix" { + address = "%s" + tls_disable = true + } +} +` + + config = fmt.Sprintf(config, in, sink1, sink2, socketf) + if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test config", "path", conf) + } + + _, cmd := testAgentCommand(t, logger) + cmd.client = client + + // Kill the command 5 seconds after it starts + go func() { + select { + case <-cmd.ShutdownCh: + case <-time.After(5 * time.Second): + cmd.ShutdownCh <- struct{}{} + } + }() + + originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress) + + // Create a client that talks to the agent + os.Setenv(api.EnvVaultAgentAddress, socketf) + testClient, err := api.NewClient(api.DefaultConfig()) + if err != nil { + t.Fatal(err) + } + os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress) + + // Start the agent + go cmd.Run([]string{"-config", conf}) + + // Give some time for the auto-auth to complete + time.Sleep(1 * time.Second) + + // Invoke lookup self through the agent + secret, err := testClient.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Data == nil || secret.Data["id"].(string) == "" { + t.Fatalf("failed to perform lookup self through agent") + } +} + func TestExitAfterAuth(t *testing.T) { logger := logging.NewVaultLogger(hclog.Trace) coreConfig := &vault.CoreConfig{ diff --git a/command/base.go b/command/base.go index db37fd37c380..144e16435a80 100644 --- a/command/base.go +++ b/command/base.go @@ -39,6 +39,7 @@ type BaseCommand struct { flagsOnce sync.Once flagAddress string + flagAgentAddress string flagCACert string flagCAPath string flagClientCert string @@ -78,6 +79,9 @@ func (c *BaseCommand) Client() (*api.Client, error) { if c.flagAddress != "" { config.Address = c.flagAddress } + if c.flagAgentAddress != "" { + config.Address = c.flagAgentAddress + } if c.flagOutputCurlString { config.OutputCurlString = c.flagOutputCurlString @@ -220,6 +224,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { } f.StringVar(addrStringVar) + agentAddrStringVar := &StringVar{ + Name: "agent-address", + Target: &c.flagAgentAddress, + EnvVar: "VAULT_AGENT_ADDR", + Completion: complete.PredictAnything, + Usage: "Address of the Agent.", + } + f.StringVar(agentAddrStringVar) + f.StringVar(&StringVar{ Name: "ca-cert", Target: &c.flagCACert, diff --git a/command/server/listener.go b/command/server/listener.go index a1f2f392684c..6546972260f2 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -72,7 +72,7 @@ func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.List return newLn, nil } -func listenerWrapTLS( +func ListenerWrapTLS( ln net.Listener, props map[string]string, config map[string]interface{}, diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index 201e124f3aae..02b7b309fa83 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -35,7 +35,7 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) ( return nil, nil, nil, err } - ln = tcpKeepAliveListener{ln.(*net.TCPListener)} + ln = TCPKeepAliveListener{ln.(*net.TCPListener)} ln, err = listenerWrapProxy(ln, config) if err != nil { @@ -94,20 +94,20 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) ( config["x_forwarded_for_reject_not_authorized"] = true } - return listenerWrapTLS(ln, props, config, ui) + return ListenerWrapTLS(ln, props, config, ui) } -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by ListenAndServe and ListenAndServeTLS so // dead TCP connections (e.g. closing laptop mid-download) eventually // go away. // // This is copied directly from the Go source code. -type tcpKeepAliveListener struct { +type TCPKeepAliveListener struct { *net.TCPListener } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { +func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { tc, err := ln.AcceptTCP() if err != nil { return