diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 57075a9fb4..58e9c51ad1 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -34,8 +34,9 @@ type loginCmdOptions struct { AccessKey string Provider string SkipTLSVerify bool - // TODO: add option to pass a trusted certificate - NonInteractive bool + // TODO: add flag for trusted certificate + TrustedCertificate []byte + NonInteractive bool } type loginMethod int8 @@ -84,15 +85,32 @@ $ infra login --key 1M4CWy9wF5.fAKeKEy5sMLH9ZZzAur0ZIjy`, } func login(cli *CLI, options loginCmdOptions) error { - var err error + config, err := readConfig() + if err != nil { + return err + } if options.Server == "" { - options.Server, err = promptServer(cli, options) + if options.NonInteractive { + return Error{Message: "Non-interactive login requires the [SERVER] argument"} + } + + options.Server, err = promptServer(cli, config) if err != nil { return err } } + if len(options.TrustedCertificate) == 0 { + // Attempt to find a previously trusted certificate + for _, hc := range config.Hosts { + // TODO: comparison fails if one of these includes a scheme + if hc.Host == options.Server { + options.TrustedCertificate = hc.TrustedCertificate + } + } + } + lc, err := newLoginClient(cli, options) if err != nil { return err @@ -376,76 +394,66 @@ type loginClient struct { // Only used when logging in or switching to a new session, since user has no credentials. Otherwise, use defaultAPIClient(). func newLoginClient(cli *CLI, options loginCmdOptions) (loginClient, error) { - c := loginClient{} - if !options.SkipTLSVerify { - // Prompt user only if server fails the TLS verification - if err := attemptTLSRequest(options.Server); err != nil { - var uaErr x509.UnknownAuthorityError - if !errors.As(err, &uaErr) { - return c, err - } + cfg := &ClientHostConfig{ + TrustedCertificate: options.TrustedCertificate, + SkipTLSVerify: options.SkipTLSVerify, + } + c := loginClient{ + APIClient: apiClient(options.Server, "", httpTransportForHostConfig(cfg)), + TrustedCertificate: options.TrustedCertificate, + } + if options.SkipTLSVerify { + return c, nil + } - if options.NonInteractive { - // TODO: add the --tls-ca flag - // TODO: give a different error if the flag was set - return c, Error{ - Message: "The authenticity of the server could not be verified. " + - "Use the --tls-ca flag to specify a trusted CA, or run " + - "in interactive mode.", - } - } + // Prompt user only if server fails the TLS verification + if err := attemptTLSRequest(c.APIClient); err != nil { + var uaErr x509.UnknownAuthorityError + if !errors.As(err, &uaErr) { + return c, err + } - if err = promptVerifyTLSCert(cli, uaErr.Cert); err != nil { - return c, err - } - // TODO: cleanup - pool, err := x509.SystemCertPool() - if err != nil { - return c, err - } - pool.AddCert(uaErr.Cert) - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - // set min version to the same as default to make gosec linter happy - MinVersion: tls.VersionTLS12, - RootCAs: pool, - }, + if options.NonInteractive { + // TODO: add the --tls-ca flag + // TODO: give a different error if the flag was set + return c, Error{ + Message: "The authenticity of the server could not be verified. " + + "Use the --tls-ca flag to specify a trusted CA, or run " + + "in interactive mode.", } - c.APIClient = apiClient(options.Server, "", transport) - c.TrustedCertificate = uaErr.Cert.Raw - return c, nil } - } - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - //nolint:gosec // We may purposely set insecureskipverify via a flag - InsecureSkipVerify: options.SkipTLSVerify, - // TODO: read options.TrustedCertificate - //RootCAs: pool, - }, + if err = promptVerifyTLSCert(cli, uaErr.Cert); err != nil { + return c, err + } + + pool, err := x509.SystemCertPool() + if err != nil { + return c, err + } + pool.AddCert(uaErr.Cert) + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + // set min version to the same as default to make gosec linter happy + MinVersion: tls.VersionTLS12, + RootCAs: pool, + }, + } + c.APIClient = apiClient(options.Server, "", transport) + c.TrustedCertificate = uaErr.Cert.Raw } - c.APIClient = apiClient(options.Server, "", transport) return c, nil } -func attemptTLSRequest(host string) error { - // TODO: use apiClient here so that we set the right user agent, and can re-use - // error handling from the client. Use the /api/version endpoint. - reqURL, err := urlx.Parse(host) - if err != nil { - return fmt.Errorf("failed to parse the server hostname: %w", err) - } - reqURL.Scheme = "https" - - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, reqURL.String(), nil) +func attemptTLSRequest(client *api.Client) error { + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, client.URL, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } - logging.S.Debugf("call server: test tls for %q", host) + logging.S.Debugf("call server: test tls for %q", client.URL) urlErr := &url.Error{} - res, err := http.DefaultClient.Do(req) + res, err := client.HTTP.Do(req) switch { case err == nil: res.Body.Close() @@ -604,16 +612,7 @@ to manually verify the certificate can be trusted. } // Returns the host address of the Infra server that user would like to log into -func promptServer(cli *CLI, options loginCmdOptions) (string, error) { - if options.NonInteractive { - return "", Error{Message: "Non-interactive login requires the [SERVER] argument"} - } - - config, err := readConfig() - if err != nil { - return "", err - } - +func promptServer(cli *CLI, config *ClientConfig) (string, error) { servers := config.Hosts if len(servers) == 0 { diff --git a/internal/cmd/login_test.go b/internal/cmd/login_test.go index 1bbecb8545..5ac21226a4 100644 --- a/internal/cmd/login_test.go +++ b/internal/cmd/login_test.go @@ -24,7 +24,7 @@ import ( "github.com/infrahq/infra/uid" ) -func TestLoginCmd(t *testing.T) { +func TestLoginCmd_FirstLoginSetupAdmin(t *testing.T) { dir := t.TempDir() t.Setenv("HOME", dir) t.Setenv("USERPROFILE", dir) // Windows @@ -45,30 +45,6 @@ func TestLoginCmd(t *testing.T) { assert.Check(t, srv.Run(ctx)) }() - runStep(t, "reject server certificate", func(t *testing.T) { - ctx, cancel := context.WithCancel(ctx) - t.Cleanup(cancel) - - console := newConsole(t) - ctx = PatchCLIWithPTY(ctx, console.Tty()) - - g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - // TODO: why isn't this working without --non-interactive=false? the other test works - return Run(ctx, "login", "--non-interactive=false", srv.Addrs.HTTPS.String()) - }) - exp := expector{console: console} - exp.ExpectString(t, "verify the certificate can be trusted") - exp.Send(t, "I do not trust this certificate\n") - - assert.ErrorIs(t, g.Wait(), terminal.InterruptErr) - - // Check we haven't persisted any certificates - cfg, err := readConfig() - assert.NilError(t, err) - assert.Equal(t, len(cfg.Hosts), 0) - }) - runStep(t, "first login prompts for setup", func(t *testing.T) { ctx, cancel := context.WithCancel(ctx) t.Cleanup(cancel) @@ -120,63 +96,6 @@ func TestLoginCmd(t *testing.T) { expected := clientcmdapi.Config{} assert.DeepEqual(t, expected, kubeCfg, cmpopts.EquateEmpty()) }) - - runStep(t, "trust server certificate", func(t *testing.T) { - ctx, cancel := context.WithCancel(ctx) - t.Cleanup(cancel) - - // Create an access key for login - cfg, err := readConfig() - assert.NilError(t, err) - assert.Equal(t, len(cfg.Hosts), 1) - userID, _ := cfg.Hosts[0].PolymorphicID.ID() - client, err := defaultAPIClient() - assert.NilError(t, err) - - resp, err := client.CreateAccessKey(&api.CreateAccessKeyRequest{ - UserID: userID, - TTL: api.Duration(opts.SessionDuration + time.Minute), - ExtensionDeadline: api.Duration(time.Minute), - }) - assert.NilError(t, err) - accessKey := resp.AccessKey - - // Erase local data for previous login session - assert.NilError(t, writeConfig(&ClientConfig{Version: clientConfigVersion})) - - console := newConsole(t) - ctx = PatchCLIWithPTY(ctx, console.Tty()) - - g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - // TODO: why isn't this working without --non-interactive=false? the other test works - return Run(ctx, "login", "--non-interactive=false", "--key", accessKey, srv.Addrs.HTTPS.String()) - }) - exp := expector{console: console} - exp.ExpectString(t, "verify the certificate can be trusted") - exp.Send(t, "Trust and save the certificate\n") - - assert.NilError(t, g.Wait()) - - cert, err := os.ReadFile("testdata/pki/localhost.crt") - assert.NilError(t, err) - - // Check the client config - cfg, err = readConfig() - assert.NilError(t, err) - expected := []ClientHostConfig{ - { - Name: "admin@example.com", - AccessKey: "any-access-key", - PolymorphicID: "any-id", - Host: srv.Addrs.HTTPS.String(), - Expires: api.Time(time.Now().UTC().Add(opts.SessionDuration)), - Current: true, - TrustedCertificate: cert, - }, - } - assert.DeepEqual(t, cfg.Hosts, expected, cmpClientHostConfig) - }) } var cmpClientHostConfig = cmp.Options{ @@ -277,3 +196,121 @@ func setupCertManager(t *testing.T, dir string, serverName string) { err = cache.Put(ctx, serverName+".crt", cert) assert.NilError(t, err) } + +func TestLoginCmd_PromptForTLSVerify(t *testing.T) { + dir := t.TempDir() + t.Setenv("HOME", dir) + t.Setenv("USERPROFILE", dir) // Windows + // k8s.io/tools/clientcmd reads HOME at import time, so this must be patched too + kubeConfigPath := filepath.Join(dir, "kube.config") + t.Setenv("KUBECONFIG", kubeConfigPath) + + opts := defaultServerOptions(dir) + accessKey := "0000000001.adminadminadminadmin1234" + opts.Users = []server.User{ + {Name: "admin@example.com", AccessKey: accessKey}, + } + opts.Addr = server.ListenerOptions{HTTPS: "127.0.0.1:0", HTTP: "127.0.0.1:0"} + srv, err := server.New(opts) + assert.NilError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + setupCertManager(t, opts.TLSCache, srv.Addrs.HTTPS.String()) + go func() { + assert.Check(t, srv.Run(ctx)) + }() + + runStep(t, "reject server certificate", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + console := newConsole(t) + ctx = PatchCLIWithPTY(ctx, console.Tty()) + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + // TODO: why isn't this working without --non-interactive=false? the other test works + return Run(ctx, "login", "--non-interactive=false", srv.Addrs.HTTPS.String()) + }) + exp := expector{console: console} + exp.ExpectString(t, "verify the certificate can be trusted") + exp.Send(t, "NO\n") + + assert.ErrorIs(t, g.Wait(), terminal.InterruptErr) + + // Check we haven't persisted any certificates + cfg, err := readConfig() + assert.NilError(t, err) + assert.Equal(t, len(cfg.Hosts), 0) + }) + + runStep(t, "trust server certificate", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + console := newConsole(t) + ctx = PatchCLIWithPTY(ctx, console.Tty()) + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + // TODO: why isn't this working without --non-interactive=false? the other test works + return Run(ctx, "login", "--non-interactive=false", "--key", accessKey, srv.Addrs.HTTPS.String()) + }) + exp := expector{console: console} + exp.ExpectString(t, "verify the certificate can be trusted") + exp.Send(t, "TRUST\n") + + assert.NilError(t, g.Wait()) + + cert, err := os.ReadFile("testdata/pki/localhost.crt") + assert.NilError(t, err) + + // Check the client config + cfg, err := readConfig() + assert.NilError(t, err) + expected := []ClientHostConfig{ + { + Name: "admin@example.com", + AccessKey: "any-access-key", + PolymorphicID: "any-id", + Host: srv.Addrs.HTTPS.String(), + Expires: api.Time(time.Now().UTC().Add(opts.SessionDuration)), + Current: true, + TrustedCertificate: string(cert), + }, + } + assert.DeepEqual(t, cfg.Hosts, expected, cmpClientHostConfig) + }) + + runStep(t, "next login should still trust the server", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + err := Run(ctx, "logout") + assert.NilError(t, err) + + err = Run(ctx, "login", "--key", accessKey, srv.Addrs.HTTPS.String()) + assert.NilError(t, err) + + cert, err := os.ReadFile("testdata/pki/localhost.crt") + assert.NilError(t, err) + + // Check the client config + cfg, err := readConfig() + assert.NilError(t, err) + expected := []ClientHostConfig{ + { + Name: "admin@example.com", + AccessKey: "any-access-key", + PolymorphicID: "any-id", + Host: srv.Addrs.HTTPS.String(), + Expires: api.Time(time.Now().UTC().Add(opts.SessionDuration)), + Current: true, + TrustedCertificate: string(cert), + }, + } + assert.DeepEqual(t, cfg.Hosts, expected, cmpClientHostConfig) + }) +}