diff --git a/api/utils/prompt/mock.go b/api/utils/prompt/mock.go index de707584dcad3..2ca53b68d2d6d 100644 --- a/api/utils/prompt/mock.go +++ b/api/utils/prompt/mock.go @@ -16,15 +16,18 @@ package prompt import ( "context" - "errors" "sync" + "time" + + "github.com/gravitational/trace" ) type FakeReplyFunc func(context.Context) (string, error) type FakeReader struct { - mu sync.Mutex - replies []FakeReplyFunc + mu sync.Mutex + replies []FakeReplyFunc + waitingForReply chan struct{} } // NewFakeReader returns a fake that can be used in place of a ContextReader. @@ -42,6 +45,10 @@ func (r *FakeReader) AddReply(fn FakeReplyFunc) *FakeReader { r.mu.Lock() defer r.mu.Unlock() r.replies = append(r.replies, fn) + if r.waitingForReply != nil { + close(r.waitingForReply) + r.waitingForReply = nil + } return r } @@ -60,8 +67,19 @@ func (r *FakeReader) AddError(err error) *FakeReader { func (r *FakeReader) ReadContext(ctx context.Context) ([]byte, error) { r.mu.Lock() if len(r.replies) == 0 { + // wait for a reply + wait := make(chan struct{}) + r.waitingForReply = wait r.mu.Unlock() - return nil, errors.New("no fake replies available") + + select { + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + case <-time.After(5 * time.Second): + return nil, trace.BadParameter("no fake replies available after wait") + case <-wait: + r.mu.Lock() + } } // Pop first reply. diff --git a/lib/auth/authclient/authclient.go b/lib/auth/authclient/authclient.go index aafd4498fcb85..449caa7fcf7af 100644 --- a/lib/auth/authclient/authclient.go +++ b/lib/auth/authclient/authclient.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/teleport/api/breaker" apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/webclient" - "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" @@ -52,9 +51,6 @@ type Config struct { CircuitBreakerConfig breaker.Config // DialTimeout determines how long to wait for dialing to succeed before aborting. DialTimeout time.Duration - // MFAPromptConstructor is used to create MFA prompts when needed. - // If nil, the client will not prompt for MFA. - MFAPromptConstructor mfa.PromptConstructor // Insecure turns off TLS certificate verification when enabled. Insecure bool } @@ -97,7 +93,6 @@ func connectViaAuthDirect(ctx context.Context, cfg *Config) (*auth.Client, error CircuitBreakerConfig: cfg.CircuitBreakerConfig, InsecureAddressDiscovery: cfg.Insecure, DialTimeout: cfg.DialTimeout, - MFAPromptConstructor: cfg.MFAPromptConstructor, }) if err != nil { return nil, trace.Wrap(err) @@ -149,7 +144,6 @@ func connectViaProxyTunnel(ctx context.Context, cfg *Config) (*auth.Client, erro Credentials: []apiclient.Credentials{ apiclient.LoadTLS(cfg.TLS), }, - MFAPromptConstructor: cfg.MFAPromptConstructor, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/mfa/cli.go b/lib/client/mfa/cli.go index 853d18b6164b6..f434d68c6edd3 100644 --- a/lib/client/mfa/cli.go +++ b/lib/client/mfa/cli.go @@ -62,7 +62,6 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng spawnGoroutines := func(ctx context.Context, wg *sync.WaitGroup, respC chan<- MFAGoroutineResponse) { // Use variables below to cancel OTP reads and make sure the goroutine exited. otpCtx, otpCancel := context.WithCancel(ctx) - defer otpCancel() otpDone := make(chan struct{}) otpCancelAndWait := func() { otpCancel() @@ -74,6 +73,7 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng wg.Add(1) go func() { defer wg.Done() + defer otpCancel() defer close(otpDone) // Let Webauthn take the prompt below if applicable. diff --git a/lib/client/mfa/cli_test.go b/lib/client/mfa/cli_test.go new file mode 100644 index 0000000000000..ce15325a4d78c --- /dev/null +++ b/lib/client/mfa/cli_test.go @@ -0,0 +1,170 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mfa_test + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + webauthnpb "github.com/gravitational/teleport/api/types/webauthn" + "github.com/gravitational/teleport/api/utils/prompt" + wancli "github.com/gravitational/teleport/lib/auth/webauthncli" + wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" + "github.com/gravitational/teleport/lib/client/mfa" +) + +func TestCLIPrompt(t *testing.T) { + ctx := context.Background() + + for _, tc := range []struct { + name string + stdin string + challenge *proto.MFAAuthenticateChallenge + expectErr error + expectStdOut string + expectResp *proto.MFAAuthenticateResponse + }{ + { + name: "OK webauthn", + expectStdOut: "Tap any security key\n", + challenge: &proto.MFAAuthenticateChallenge{ + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + }, + expectResp: &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthnpb.CredentialAssertionResponse{}, + }, + }, + }, { + name: "OK totp", + expectStdOut: "Enter an OTP code from a device:\n", + stdin: "123456", + challenge: &proto.MFAAuthenticateChallenge{ + TOTP: &proto.TOTPChallenge{}, + }, + expectResp: &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_TOTP{ + TOTP: &proto.TOTPResponse{ + Code: "123456", + }, + }, + }, + }, { + name: "OK webauthn or totp choose webauthn", + expectStdOut: "Tap any security key or enter a code from a OTP device\n", + challenge: &proto.MFAAuthenticateChallenge{ + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + TOTP: &proto.TOTPChallenge{}, + }, + expectResp: &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthnpb.CredentialAssertionResponse{}, + }, + }, + }, { + name: "OK webauthn or totp choose totp", + expectStdOut: "Tap any security key or enter a code from a OTP device\n", + stdin: "123456", + challenge: &proto.MFAAuthenticateChallenge{ + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + TOTP: &proto.TOTPChallenge{}, + }, + expectResp: &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_TOTP{ + TOTP: &proto.TOTPResponse{ + Code: "123456", + }, + }, + }, + }, { + name: "NOK no webauthn response", + expectStdOut: "Tap any security key\n", + challenge: &proto.MFAAuthenticateChallenge{ + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + }, + expectErr: context.DeadlineExceeded, + }, { + name: "NOK no totp response", + expectStdOut: "Enter an OTP code from a device:\n", + challenge: &proto.MFAAuthenticateChallenge{ + TOTP: &proto.TOTPChallenge{}, + }, + expectErr: context.DeadlineExceeded, + }, { + name: "NOK no webauthn or totp response", + expectStdOut: "Tap any security key or enter a code from a OTP device\n", + challenge: &proto.MFAAuthenticateChallenge{ + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + TOTP: &proto.TOTPChallenge{}, + }, + expectErr: context.DeadlineExceeded, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + oldStdin := prompt.Stdin() + t.Cleanup(func() { prompt.SetStdin(oldStdin) }) + + stdin := prompt.NewFakeReader() + if tc.stdin != "" { + stdin.AddString(tc.stdin) + } + prompt.SetStdin(stdin) + + cfg := mfa.NewPromptConfig("proxy.example.com") + cfg.AllowStdinHijack = true + cfg.WebauthnSupported = true + cfg.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + if _, err := prompt.PromptTouch(); err != nil { + return nil, "", trace.Wrap(err) + } + + if tc.expectResp.GetWebauthn() == nil { + <-ctx.Done() + return nil, "", trace.Wrap(ctx.Err()) + } + + return tc.expectResp, "", nil + } + + buffer := make([]byte, 0, 100) + out := bytes.NewBuffer(buffer) + + prompt := mfa.NewCLIPrompt(cfg, out) + resp, err := prompt.Run(ctx, tc.challenge) + + if tc.expectErr != nil { + require.ErrorIs(t, err, tc.expectErr) + } else { + require.NoError(t, err) + } + + require.Equal(t, tc.expectResp, resp) + require.Equal(t, tc.expectStdOut, out.String()) + }) + } +}