From f0abc35c2e70ecd0f224d5b81f20da6b93864c49 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 3 Jun 2022 16:25:42 -0400 Subject: [PATCH] improve: pem encode the certificate And resolve a few TODOs around errors --- internal/cmd/cmd.go | 14 ++++++++------ internal/cmd/cmd_test.go | 1 + internal/cmd/config.go | 3 ++- internal/cmd/login.go | 27 ++++++++++++++++++++++----- internal/cmd/login_test.go | 8 ++------ 5 files changed, 35 insertions(+), 18 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index b1a117e8cf..fd9fd5ea75 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -83,7 +83,7 @@ func defaultAPIClient() (*api.Client, error) { return apiClient(config.Host, config.AccessKey, httpTransportForHostConfig(config)) } -// TODO: accept a url.URL to remove the error return value +// TODO: accept a url.URL to remove the error return value (github.com/infrahq/infra/issues/1627) func apiClient(host string, accessKey string, transport *http.Transport) (*api.Client, error) { u, err := urlx.Parse(host) if err != nil { @@ -110,15 +110,17 @@ func apiClient(host string, accessKey string, transport *http.Transport) (*api.C func httpTransportForHostConfig(config *ClientHostConfig) *http.Transport { pool, err := x509.SystemCertPool() if err != nil { - // TODO: print warning about this case + logging.S.Warnf("Failed to load trusted certificates from system: %v", err) pool = x509.NewCertPool() } if len(config.TrustedCertificate) > 0 { - // TODO: log or return the error - // TODO: read this in PEM format - cert, _ := x509.ParseCertificate(config.TrustedCertificate) - pool.AddCert(cert) + cert, err := x509.ParseCertificate(pemDecodeCertificate(config.TrustedCertificate)) + if err != nil { + logging.S.Warnf("Failed to read trusted certificates for server: %v", err) + } else { + pool.AddCert(cert) + } } return &http.Transport{ diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go index 416699eb7e..85d6c4d800 100644 --- a/internal/cmd/cmd_test.go +++ b/internal/cmd/cmd_test.go @@ -255,6 +255,7 @@ func newTestClientConfig(srv *httptest.Server, user api.User) ClientConfig { PolymorphicID: uid.NewIdentityPolymorphicID(user.ID), Name: user.Name, Host: srv.Listener.Addr().String(), + // TODO: change to using TrustedCertificate SkipTLSVerify: true, AccessKey: "the-access-key", Expires: api.Time(time.Now().Add(time.Hour)), diff --git a/internal/cmd/config.go b/internal/cmd/config.go index ae7ad149c8..f599e75748 100644 --- a/internal/cmd/config.go +++ b/internal/cmd/config.go @@ -27,7 +27,8 @@ type ClientHostConfig struct { ProviderID uid.ID `json:"provider-id,omitempty"` Expires api.Time `json:"expires"` Current bool `json:"current"` - // TODO: should we store this in a separate file? + // TrustedCertificate is the PEM encoded TLS certificate used by the server + // that was verified and trusted by the user as part of login. TrustedCertificate []byte `json:"trusted-certificate"` } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index b28ce62ec7..c42ee0a15e 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "net/http" @@ -229,8 +230,9 @@ func updateInfraConfig(lc loginClient, loginReq *api.LoginRequest, loginRes *api return fmt.Errorf("Could not update config due to an internal error") } clientHostConfig.SkipTLSVerify = t.TLSClientConfig.InsecureSkipVerify - // TODO: save this in PEM format - clientHostConfig.TrustedCertificate = lc.TrustedCertificate + if len(lc.TrustedCertificate) > 0 { + clientHostConfig.TrustedCertificate = pemEncodeCertificate(lc.TrustedCertificate) + } if loginReq.OIDC != nil { clientHostConfig.ProviderID = loginReq.OIDC.ProviderID @@ -249,6 +251,15 @@ func updateInfraConfig(lc loginClient, loginReq *api.LoginRequest, loginRes *api return nil } +func pemEncodeCertificate(raw []byte) []byte { + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: raw}) +} + +func pemDecodeCertificate(raw []byte) []byte { + block, _ := pem.Decode(raw) + return block.Bytes +} + func oidcflow(host string, clientId string) (string, error) { // find out what the authorization endpoint is provider, err := oidc.NewProvider(context.Background(), fmt.Sprintf("https://%s", host)) @@ -402,7 +413,9 @@ func newLoginClient(cli *CLI, options loginCmdOptions) (loginClient, error) { } pool.AddCert(uaErr.Cert) transport := &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: pool}, + TLSClientConfig: &tls.Config{ + RootCAs: pool, + }, } c.APIClient, err = apiClient(options.Server, "", transport) c.TrustedCertificate = uaErr.Cert.Raw @@ -545,7 +558,11 @@ func promptVerifyTLSCert(cli *CLI, cert *x509.Certificate) error { // TODO: improve this message // TODO: use color/bold to highlight important parts // TODO: test format with golden - fmt.Fprintf(cli.Stderr, `The certificate presented by the server could not be automatically verified. + fmt.Fprintf(cli.Stderr, ` +The certificate presented by the server is not trusted by your operating system. It +could not be automatically verified. + +Certificate Subject: %[1]s Issuer: %[2]s @@ -597,7 +614,7 @@ to manually verify the certificate can be trusted. // TODO: move this to a shared place. func fingerprint(cert *x509.Certificate) string { raw := sha256.Sum256(cert.Raw) - return strings.Replace(fmt.Sprintf("% x", raw), " ", ":", -1) + return strings.ReplaceAll(fmt.Sprintf("% x", raw), " ", ":") } // Returns the host address of the Infra server that user would like to log into diff --git a/internal/cmd/login_test.go b/internal/cmd/login_test.go index c1037474a1..b2f16a2cfe 100644 --- a/internal/cmd/login_test.go +++ b/internal/cmd/login_test.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "encoding/pem" "os" "path/filepath" "testing" @@ -162,9 +161,6 @@ func TestLoginCmd(t *testing.T) { cert, err := os.ReadFile("testdata/pki/localhost.crt") assert.NilError(t, err) - // TODO: remove once stored as PEM encoded - block, _ := pem.Decode(cert) - // Check the client config cfg, err = readConfig() assert.NilError(t, err) @@ -176,7 +172,7 @@ func TestLoginCmd(t *testing.T) { Host: srv.Addrs.HTTPS.String(), Expires: api.Time(time.Now().UTC().Add(opts.SessionDuration)), Current: true, - TrustedCertificate: block.Bytes, + TrustedCertificate: cert, }, } // TODO: where is the extra entry coming from? @@ -227,7 +223,7 @@ func newConsole(t *testing.T) *expect.Console { pseudoTY, tty, err := pty.Open() assert.NilError(t, err, "failed to open pseudo tty") - timeout := time.Hour // TODO: time.Second + timeout := time.Second if os.Getenv("CI") != "" { // CI takes much longer than local dev, use a much longer timeout timeout = 20 * time.Second