Skip to content

Commit

Permalink
Fix OTP prompt (#35537)
Browse files Browse the repository at this point in the history
* Fix otp prompt; remove related dead code.

* Implement basic read waiting for fake readers in tests.

* Add regression test.
  • Loading branch information
Joerger authored Dec 12, 2023
1 parent 728d709 commit e99b342
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 11 deletions.
26 changes: 22 additions & 4 deletions api/utils/prompt/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@ package prompt

import (
"context"
"errors"
"sync"
"time"

"github.com/gravitational/trace"
)

type FakeReplyFunc func(context.Context) (string, error)

type FakeReader struct {
mu sync.Mutex
replies []FakeReplyFunc
mu sync.Mutex
replies []FakeReplyFunc
waitingForReply chan struct{}
}

// NewFakeReader returns a fake that can be used in place of a ContextReader.
Expand All @@ -42,6 +45,10 @@ func (r *FakeReader) AddReply(fn FakeReplyFunc) *FakeReader {
r.mu.Lock()
defer r.mu.Unlock()
r.replies = append(r.replies, fn)
if r.waitingForReply != nil {
close(r.waitingForReply)
r.waitingForReply = nil
}
return r
}

Expand All @@ -60,8 +67,19 @@ func (r *FakeReader) AddError(err error) *FakeReader {
func (r *FakeReader) ReadContext(ctx context.Context) ([]byte, error) {
r.mu.Lock()
if len(r.replies) == 0 {
// wait for a reply
wait := make(chan struct{})
r.waitingForReply = wait
r.mu.Unlock()
return nil, errors.New("no fake replies available")

select {
case <-ctx.Done():
return nil, trace.Wrap(ctx.Err())
case <-time.After(5 * time.Second):
return nil, trace.BadParameter("no fake replies available after wait")
case <-wait:
r.mu.Lock()
}
}

// Pop first reply.
Expand Down
6 changes: 0 additions & 6 deletions lib/auth/authclient/authclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/gravitational/teleport/api/breaker"
apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/utils"
Expand All @@ -52,9 +51,6 @@ type Config struct {
CircuitBreakerConfig breaker.Config
// DialTimeout determines how long to wait for dialing to succeed before aborting.
DialTimeout time.Duration
// MFAPromptConstructor is used to create MFA prompts when needed.
// If nil, the client will not prompt for MFA.
MFAPromptConstructor mfa.PromptConstructor
// Insecure turns off TLS certificate verification when enabled.
Insecure bool
}
Expand Down Expand Up @@ -97,7 +93,6 @@ func connectViaAuthDirect(ctx context.Context, cfg *Config) (*auth.Client, error
CircuitBreakerConfig: cfg.CircuitBreakerConfig,
InsecureAddressDiscovery: cfg.Insecure,
DialTimeout: cfg.DialTimeout,
MFAPromptConstructor: cfg.MFAPromptConstructor,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -149,7 +144,6 @@ func connectViaProxyTunnel(ctx context.Context, cfg *Config) (*auth.Client, erro
Credentials: []apiclient.Credentials{
apiclient.LoadTLS(cfg.TLS),
},
MFAPromptConstructor: cfg.MFAPromptConstructor,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/client/mfa/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng
spawnGoroutines := func(ctx context.Context, wg *sync.WaitGroup, respC chan<- MFAGoroutineResponse) {
// Use variables below to cancel OTP reads and make sure the goroutine exited.
otpCtx, otpCancel := context.WithCancel(ctx)
defer otpCancel()
otpDone := make(chan struct{})
otpCancelAndWait := func() {
otpCancel()
Expand All @@ -74,6 +73,7 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng
wg.Add(1)
go func() {
defer wg.Done()
defer otpCancel()
defer close(otpDone)

// Let Webauthn take the prompt below if applicable.
Expand Down
170 changes: 170 additions & 0 deletions lib/client/mfa/cli_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Teleport
* Copyright (C) 2023 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package mfa_test

import (
"bytes"
"context"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
webauthnpb "github.com/gravitational/teleport/api/types/webauthn"
"github.com/gravitational/teleport/api/utils/prompt"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client/mfa"
)

func TestCLIPrompt(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
name string
stdin string
challenge *proto.MFAAuthenticateChallenge
expectErr error
expectStdOut string
expectResp *proto.MFAAuthenticateResponse
}{
{
name: "OK webauthn",
expectStdOut: "Tap any security key\n",
challenge: &proto.MFAAuthenticateChallenge{
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
},
expectResp: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: &webauthnpb.CredentialAssertionResponse{},
},
},
}, {
name: "OK totp",
expectStdOut: "Enter an OTP code from a device:\n",
stdin: "123456",
challenge: &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
},
expectResp: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Code: "123456",
},
},
},
}, {
name: "OK webauthn or totp choose webauthn",
expectStdOut: "Tap any security key or enter a code from a OTP device\n",
challenge: &proto.MFAAuthenticateChallenge{
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
TOTP: &proto.TOTPChallenge{},
},
expectResp: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: &webauthnpb.CredentialAssertionResponse{},
},
},
}, {
name: "OK webauthn or totp choose totp",
expectStdOut: "Tap any security key or enter a code from a OTP device\n",
stdin: "123456",
challenge: &proto.MFAAuthenticateChallenge{
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
TOTP: &proto.TOTPChallenge{},
},
expectResp: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Code: "123456",
},
},
},
}, {
name: "NOK no webauthn response",
expectStdOut: "Tap any security key\n",
challenge: &proto.MFAAuthenticateChallenge{
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
},
expectErr: context.DeadlineExceeded,
}, {
name: "NOK no totp response",
expectStdOut: "Enter an OTP code from a device:\n",
challenge: &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
},
expectErr: context.DeadlineExceeded,
}, {
name: "NOK no webauthn or totp response",
expectStdOut: "Tap any security key or enter a code from a OTP device\n",
challenge: &proto.MFAAuthenticateChallenge{
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
TOTP: &proto.TOTPChallenge{},
},
expectErr: context.DeadlineExceeded,
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()

oldStdin := prompt.Stdin()
t.Cleanup(func() { prompt.SetStdin(oldStdin) })

stdin := prompt.NewFakeReader()
if tc.stdin != "" {
stdin.AddString(tc.stdin)
}
prompt.SetStdin(stdin)

cfg := mfa.NewPromptConfig("proxy.example.com")
cfg.AllowStdinHijack = true
cfg.WebauthnSupported = true
cfg.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
if _, err := prompt.PromptTouch(); err != nil {
return nil, "", trace.Wrap(err)
}

if tc.expectResp.GetWebauthn() == nil {
<-ctx.Done()
return nil, "", trace.Wrap(ctx.Err())
}

return tc.expectResp, "", nil
}

buffer := make([]byte, 0, 100)
out := bytes.NewBuffer(buffer)

prompt := mfa.NewCLIPrompt(cfg, out)
resp, err := prompt.Run(ctx, tc.challenge)

if tc.expectErr != nil {
require.ErrorIs(t, err, tc.expectErr)
} else {
require.NoError(t, err)
}

require.Equal(t, tc.expectResp, resp)
require.Equal(t, tc.expectStdOut, out.String())
})
}
}

0 comments on commit e99b342

Please sign in to comment.