diff --git a/changelog/14131.txt b/changelog/14131.txt new file mode 100644 index 000000000000..e2d2d87b8688 --- /dev/null +++ b/changelog/14131.txt @@ -0,0 +1,3 @@ +```release-note:improvement +cli: interactive CLI for login mfa +``` diff --git a/command/base.go b/command/base.go index 558ec4993681..9521e17225b6 100644 --- a/command/base.go +++ b/command/base.go @@ -55,6 +55,7 @@ type BaseCommand struct { flagFormat string flagField string flagOutputCurlString bool + flagNonInteractive bool flagMFA []string @@ -393,6 +394,13 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { "This can be specified multiple times.", }) + f.BoolVar(&BoolVar{ + Name: "non-interactive", + Target: &c.flagNonInteractive, + Default: false, + Usage: "When set true, prevents asking the user for input via the terminal.", + }) + } if bit&(FlagSetOutputField|FlagSetOutputFormat) != 0 { diff --git a/command/write.go b/command/write.go index c110e9da26a8..0de7299eb8a6 100644 --- a/command/write.go +++ b/command/write.go @@ -6,6 +6,8 @@ import ( "os" "strings" + "github.com/hashicorp/vault/sdk/logical" + "github.com/mattn/go-isatty" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -15,6 +17,13 @@ var ( _ cli.CommandAutocomplete = (*WriteCommand)(nil) ) +// MFAMethodInfo contains the information about an MFA method +type MFAMethodInfo struct { + methodID string + methodType string + usePasscode bool +} + // WriteCommand is a Command that puts data into the Vault. type WriteCommand struct { *BaseCommand @@ -147,10 +156,103 @@ func (c *WriteCommand) Run(args []string) int { } if secret != nil && secret.Auth != nil && secret.Auth.MFARequirement != nil { + if c.isInteractiveEnabled(len(secret.Auth.MFARequirement.MFAConstraints)) { + // Currently, if there is only one MFA method configured, the login + // request is validated interactively + methodInfo := c.getMFAMethodInfo(secret.Auth.MFARequirement.MFAConstraints) + if methodInfo.methodID != "" { + return c.validateMFA(secret.Auth.MFARequirement.MFARequestID, methodInfo) + } + } c.UI.Warn(wrapAtLength("A login request was issued that is subject to "+ "MFA validation. Please make sure to validate the login by sending another "+ - "request to mfa/validate endpoint.") + "\n") + "request to sys/mfa/validate endpoint.") + "\n") + } + + // Handle single field output + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) + } + + return OutputSecret(c.UI, secret) +} + +func (c *WriteCommand) isInteractiveEnabled(mfaConstraintLen int) bool { + if mfaConstraintLen != 1 || !isatty.IsTerminal(os.Stdin.Fd()) { + return false + } + + if !c.flagNonInteractive { + return true + } + + return false +} + +// getMFAMethodInfo returns MFA method information only if one MFA method is +// configured. +func (c *WriteCommand) getMFAMethodInfo(mfaConstraintAny map[string]*logical.MFAConstraintAny) MFAMethodInfo { + for _, mfaConstraint := range mfaConstraintAny { + if len(mfaConstraint.Any) != 1 { + return MFAMethodInfo{} + } + + return MFAMethodInfo{ + methodType: mfaConstraint.Any[0].Type, + methodID: mfaConstraint.Any[0].ID, + usePasscode: mfaConstraint.Any[0].UsesPasscode, + } + } + + return MFAMethodInfo{} +} + +func (c *WriteCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { + var passcode string + var err error + if methodInfo.usePasscode { + passcode, err = c.UI.AskSecret(fmt.Sprintf("Enter the passphrase for methodID %q of type %q:", methodInfo.methodID, methodInfo.methodType)) + if err != nil { + c.UI.Error(fmt.Sprintf("failed to read the passphrase with error %q. please validate the login by sending a request to sys/mfa/validate", err.Error())) + return 2 + } + } else { + c.UI.Warn("Asking Vault to perform MFA validation with upstream service. " + + "You should receive a push notification in your authenticator app shortly") } + + // passcode could be an empty string + mfaPayload := map[string][]string{ + methodInfo.methodID: {passcode}, + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + path := "sys/mfa/validate" + + secret, err := client.Logical().Write(path, map[string]interface{}{ + "mfa_request_id": reqID, + "mfa_payload": mfaPayload, + }) + if err != nil { + c.UI.Error(err.Error()) + if secret != nil { + OutputSecret(c.UI, secret) + } + return 2 + } + if secret == nil { + // Don't output anything unless using the "table" format + if Format(c.UI) == "table" { + c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) + } + return 0 + } + // Handle single field output if c.flagField != "" { return PrintRawField(c.UI, secret, c.flagField) diff --git a/vault/core.go b/vault/core.go index 1bca8a2cea35..c1579d12121f 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2097,7 +2097,6 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c if err := c.setupQuotas(ctx, false); err != nil { return err } - c.setupCachedMFAResponseAuth() if err := c.setupHeaderHMACKey(ctx, false); err != nil { diff --git a/vault/core_util.go b/vault/core_util.go index e057c35a2526..965cc32aed41 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -135,10 +135,6 @@ func (c *Core) collectNamespaces() []*namespace.Namespace { } } -func (c *Core) namepaceByPath(string) *namespace.Namespace { - return namespace.RootNamespace -} - func (c *Core) HasWALState(required *logical.WALState, perfStandby bool) bool { return true } diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index dd7b25628ea5..19869bdb7825 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -1,12 +1,14 @@ package identity import ( + "context" "fmt" "strings" "testing" "time" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/builtin/logical/totp" vaulthttp "github.com/hashicorp/vault/http" @@ -14,29 +16,45 @@ import ( "github.com/hashicorp/vault/vault" ) -var loginMFACoreConfig = &vault.CoreConfig{ - CredentialBackends: map[string]logical.Factory{ - "userpass": userpass.Factory, - }, - LogicalBackends: map[string]logical.Factory{ - "totp": totp.Factory, +func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { + var noop *vault.NoopAudit + + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + LogicalBackends: map[string]logical.Factory{ + "totp": totp.Factory, + }, + AuditBackends: map[string]audit.Factory{ + "noop": func(ctx context.Context, config *audit.BackendConfig) (audit.Backend, error) { + noop = &vault.NoopAudit{ + Config: config, + } + return noop, nil + }, + }, }, -} + &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) -func TestLoginMfaGenerateTOTPRoleTest(t *testing.T) { - cluster := vault.NewTestCluster(t, loginMFACoreConfig, &vault.TestClusterOptions{ - HandlerFunc: vaulthttp.Handler, - }) cluster.Start() defer cluster.Cleanup() client := cluster.Cores[0].Client + // Enable the audit backend + err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{Type: "noop"}) + if err != nil { + t.Fatal(err) + } + // Mount the TOTP backend mountInfo := &api.MountInput{ Type: "totp", } - err := client.Sys().Mount("totp", mountInfo) + err = client.Sys().Mount("totp", mountInfo) if err != nil { t.Fatalf("failed to mount totp backend: %v", err) } @@ -254,6 +272,24 @@ func TestLoginMfaGenerateTOTPRoleTest(t *testing.T) { t.Fatalf("MFA failed: %v", err) } + if secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatalf("successful mfa validation did not return a client token") + } + + if noop.Req == nil { + t.Fatalf("no request was logged in audit log") + } + var found bool + for _, req := range noop.Req { + if req.Path == "sys/mfa/validate" { + found = true + break + } + } + if !found { + t.Fatalf("mfa/validate was not logged in audit log") + } + // check for login request expiration secret, err = user2Client.Logical().Write("auth/userpass/login/testuser", map[string]interface{}{ "password": "testpassword",