Skip to content

Commit

Permalink
improve: pem encode the certificate
Browse files Browse the repository at this point in the history
And resolve a few TODOs around errors
  • Loading branch information
dnephin committed Jun 3, 2022
1 parent 585291b commit f0abc35
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 18 deletions.
14 changes: 8 additions & 6 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func defaultAPIClient() (*api.Client, error) {
return apiClient(config.Host, config.AccessKey, httpTransportForHostConfig(config))
}

// TODO: accept a url.URL to remove the error return value
// TODO: accept a url.URL to remove the error return value (github.com/infrahq/infra/issues/1627)
func apiClient(host string, accessKey string, transport *http.Transport) (*api.Client, error) {
u, err := urlx.Parse(host)
if err != nil {
Expand All @@ -110,15 +110,17 @@ func apiClient(host string, accessKey string, transport *http.Transport) (*api.C
func httpTransportForHostConfig(config *ClientHostConfig) *http.Transport {
pool, err := x509.SystemCertPool()
if err != nil {
// TODO: print warning about this case
logging.S.Warnf("Failed to load trusted certificates from system: %v", err)
pool = x509.NewCertPool()
}

if len(config.TrustedCertificate) > 0 {
// TODO: log or return the error
// TODO: read this in PEM format
cert, _ := x509.ParseCertificate(config.TrustedCertificate)
pool.AddCert(cert)
cert, err := x509.ParseCertificate(pemDecodeCertificate(config.TrustedCertificate))
if err != nil {
logging.S.Warnf("Failed to read trusted certificates for server: %v", err)
} else {
pool.AddCert(cert)
}
}

return &http.Transport{
Expand Down
1 change: 1 addition & 0 deletions internal/cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ func newTestClientConfig(srv *httptest.Server, user api.User) ClientConfig {
PolymorphicID: uid.NewIdentityPolymorphicID(user.ID),
Name: user.Name,
Host: srv.Listener.Addr().String(),
// TODO: change to using TrustedCertificate
SkipTLSVerify: true,
AccessKey: "the-access-key",
Expires: api.Time(time.Now().Add(time.Hour)),
Expand Down
3 changes: 2 additions & 1 deletion internal/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ type ClientHostConfig struct {
ProviderID uid.ID `json:"provider-id,omitempty"`
Expires api.Time `json:"expires"`
Current bool `json:"current"`
// TODO: should we store this in a separate file?
// TrustedCertificate is the PEM encoded TLS certificate used by the server
// that was verified and trusted by the user as part of login.
TrustedCertificate []byte `json:"trusted-certificate"`
}

Expand Down
27 changes: 22 additions & 5 deletions internal/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -229,8 +230,9 @@ func updateInfraConfig(lc loginClient, loginReq *api.LoginRequest, loginRes *api
return fmt.Errorf("Could not update config due to an internal error")
}
clientHostConfig.SkipTLSVerify = t.TLSClientConfig.InsecureSkipVerify
// TODO: save this in PEM format
clientHostConfig.TrustedCertificate = lc.TrustedCertificate
if len(lc.TrustedCertificate) > 0 {
clientHostConfig.TrustedCertificate = pemEncodeCertificate(lc.TrustedCertificate)
}

if loginReq.OIDC != nil {
clientHostConfig.ProviderID = loginReq.OIDC.ProviderID
Expand All @@ -249,6 +251,15 @@ func updateInfraConfig(lc loginClient, loginReq *api.LoginRequest, loginRes *api
return nil
}

func pemEncodeCertificate(raw []byte) []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: raw})
}

func pemDecodeCertificate(raw []byte) []byte {
block, _ := pem.Decode(raw)
return block.Bytes
}

func oidcflow(host string, clientId string) (string, error) {
// find out what the authorization endpoint is
provider, err := oidc.NewProvider(context.Background(), fmt.Sprintf("https://%s", host))
Expand Down Expand Up @@ -402,7 +413,9 @@ func newLoginClient(cli *CLI, options loginCmdOptions) (loginClient, error) {
}
pool.AddCert(uaErr.Cert)
transport := &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: pool},
TLSClientConfig: &tls.Config{
RootCAs: pool,
},
}
c.APIClient, err = apiClient(options.Server, "", transport)
c.TrustedCertificate = uaErr.Cert.Raw
Expand Down Expand Up @@ -545,7 +558,11 @@ func promptVerifyTLSCert(cli *CLI, cert *x509.Certificate) error {
// TODO: improve this message
// TODO: use color/bold to highlight important parts
// TODO: test format with golden
fmt.Fprintf(cli.Stderr, `The certificate presented by the server could not be automatically verified.
fmt.Fprintf(cli.Stderr, `
The certificate presented by the server is not trusted by your operating system. It
could not be automatically verified.
Certificate
Subject: %[1]s
Issuer: %[2]s
Expand Down Expand Up @@ -597,7 +614,7 @@ to manually verify the certificate can be trusted.
// TODO: move this to a shared place.
func fingerprint(cert *x509.Certificate) string {
raw := sha256.Sum256(cert.Raw)
return strings.Replace(fmt.Sprintf("% x", raw), " ", ":", -1)
return strings.ReplaceAll(fmt.Sprintf("% x", raw), " ", ":")
}

// Returns the host address of the Infra server that user would like to log into
Expand Down
8 changes: 2 additions & 6 deletions internal/cmd/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"context"
"encoding/pem"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -162,9 +161,6 @@ func TestLoginCmd(t *testing.T) {
cert, err := os.ReadFile("testdata/pki/localhost.crt")
assert.NilError(t, err)

// TODO: remove once stored as PEM encoded
block, _ := pem.Decode(cert)

// Check the client config
cfg, err = readConfig()
assert.NilError(t, err)
Expand All @@ -176,7 +172,7 @@ func TestLoginCmd(t *testing.T) {
Host: srv.Addrs.HTTPS.String(),
Expires: api.Time(time.Now().UTC().Add(opts.SessionDuration)),
Current: true,
TrustedCertificate: block.Bytes,
TrustedCertificate: cert,
},
}
// TODO: where is the extra entry coming from?
Expand Down Expand Up @@ -227,7 +223,7 @@ func newConsole(t *testing.T) *expect.Console {
pseudoTY, tty, err := pty.Open()
assert.NilError(t, err, "failed to open pseudo tty")

timeout := time.Hour // TODO: time.Second
timeout := time.Second
if os.Getenv("CI") != "" {
// CI takes much longer than local dev, use a much longer timeout
timeout = 20 * time.Second
Expand Down

0 comments on commit f0abc35

Please sign in to comment.