Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: prompt to accept the server TLS certificate for manual verification #2177

Merged
merged 11 commits into from
Jun 20, 2022
6 changes: 3 additions & 3 deletions internal/certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func GenerateCertificate(hosts []string, caCert *x509.Certificate, caKey crypto.
}

keyBytes := pemEncodePrivateKey(x509.MarshalPKCS1PrivateKey(key))
return pemEncodeCertificate(certBytes), keyBytes, nil
return PEMEncodeCertificate(certBytes), keyBytes, nil
}

func SelfSignedOrLetsEncryptCert(manager *autocert.Manager, serverName string) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expand Down Expand Up @@ -136,9 +136,9 @@ func pemDecode(raw []byte) []byte {
return block.Bytes
}

// pemEncodeCertificate accepts the bytes of a x509 certificate in ASN.1 DER form
// PEMEncodeCertificate accepts the bytes of a x509 certificate in ASN.1 DER form
// and returns a PEM encoded representation of that certificate.
func pemEncodeCertificate(raw []byte) []byte {
func PEMEncodeCertificate(raw []byte) []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: raw})
}

Expand Down
39 changes: 29 additions & 10 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -77,23 +78,41 @@ func defaultAPIClient() (*api.Client, error) {
return nil, err
}

return apiClient(config.Host, config.AccessKey, config.SkipTLSVerify), nil
return apiClient(config.Host, config.AccessKey, httpTransportForHostConfig(config)), nil
}

func apiClient(host string, accessKey string, skipTLSVerify bool) *api.Client {
func apiClient(host string, accessKey string, transport *http.Transport) *api.Client {
return &api.Client{
Name: "cli",
Version: internal.Version,
URL: "https://" + host,
AccessKey: accessKey,
HTTP: http.Client{
Timeout: 60 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec // We may purposely set insecureskipverify via a flag
InsecureSkipVerify: skipTLSVerify,
},
},
Timeout: 60 * time.Second,
Transport: transport,
},
}
}

func httpTransportForHostConfig(config *ClientHostConfig) *http.Transport {
pool, err := x509.SystemCertPool()
if err != nil {
logging.S.Warnf("Failed to load trusted certificates from system: %v", err)
pool = x509.NewCertPool()
}

if config.TrustedCertificate != "" {
ok := pool.AppendCertsFromPEM([]byte(config.TrustedCertificate))
if !ok {
logging.S.Warnf("Failed to read trusted certificates for server")
}
}

return &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec // We may purposely set insecureskipverify via a flag
BruceMacD marked this conversation as resolved.
Show resolved Hide resolved
InsecureSkipVerify: config.SkipTLSVerify,
RootCAs: pool,
},
}
}
Expand Down Expand Up @@ -134,7 +153,7 @@ $ infra use development.kube-system`,
return err
}

err = updateKubeconfig(client, id)
err = updateKubeConfig(client, id)
if err != nil {
return err
}
Expand Down
17 changes: 9 additions & 8 deletions internal/cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"gotest.tools/v3/fs"

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal/certs"
"github.com/infrahq/infra/internal/connector"
"github.com/infrahq/infra/uid"
)
Expand Down Expand Up @@ -267,16 +268,16 @@ func newTestClientConfig(srv *httptest.Server, user api.User) ClientConfig {
user.ID = uid.New()
}
return ClientConfig{
Version: "0.3",
Version: clientConfigVersion,
Hosts: []ClientHostConfig{
{
PolymorphicID: uid.NewIdentityPolymorphicID(user.ID),
Name: user.Name,
Host: srv.Listener.Addr().String(),
SkipTLSVerify: true,
AccessKey: "the-access-key",
Expires: api.Time(time.Now().Add(time.Hour)),
Current: true,
PolymorphicID: uid.NewIdentityPolymorphicID(user.ID),
Name: user.Name,
Host: srv.Listener.Addr().String(),
TrustedCertificate: string(certs.PEMEncodeCertificate(srv.Certificate().Raw)),
AccessKey: "the-access-key",
Expires: api.Time(time.Now().Add(time.Hour)),
Current: true,
},
},
}
Expand Down
18 changes: 10 additions & 8 deletions internal/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ import (
"github.com/infrahq/infra/uid"
)

// current: v0.3
// clientConfigVersion is the current version of the client configuration file.
// Use this constant when referring to a value in tests or code that should
// always use latest.
const clientConfigVersion = "0.3"

type ClientConfig struct {
Version string `json:"version"`
Hosts []ClientHostConfig `json:"hosts"`
}

// current: v0.3
type ClientHostConfig struct {
PolymorphicID uid.PolymorphicID `json:"polymorphic-id"` // identity pid TODO: name this. what's it an ID of?
Name string `json:"name"` // identity name
Expand All @@ -27,6 +30,9 @@ type ClientHostConfig struct {
ProviderID uid.ID `json:"provider-id,omitempty"`
Expires api.Time `json:"expires"`
Current bool `json:"current"`
// TrustedCertificate is the PEM encoded TLS certificate used by the server
// that was verified and trusted by the user as part of login.
TrustedCertificate string `json:"trusted-certificate"`
}

// checks if user is logged in to the given session (ClientHostConfig)
Expand Down Expand Up @@ -65,7 +71,7 @@ func readConfig() (*ClientConfig, error) {
}

if len(contents) == 0 {
return &ClientConfig{Version: "0.3"}, nil
return &ClientConfig{Version: clientConfigVersion}, nil
}

config := &ClientConfig{}
Expand Down Expand Up @@ -106,11 +112,7 @@ func writeConfig(config *ClientConfig) error {
return err
}

if err = ioutil.WriteFile(filepath.Join(infraDir, "config"), []byte(contents), 0o600); err != nil {
return err
}

return nil
return ioutil.WriteFile(filepath.Join(infraDir, "config"), contents, 0o600)
}

// Save (create or update) the current hostconfig
Expand Down
5 changes: 0 additions & 5 deletions internal/cmd/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ var (
ErrUserNotFound = errors.New(`no user found with this name`)
)

// User facing constant errors: to let user know why their command failed. Not meant for a stack trace, but a readable output of the reason for failure.
var (
ErrTLSNotVerified = errors.New(`The authenticity of the host can't be established.`)
)

type LoginError struct {
Message string
}
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func kubernetesSetContext(cluster, namespace string) error {
return nil
}

func updateKubeconfig(client *api.Client, id uid.ID) error {
func updateKubeConfig(client *api.Client, id uid.ID) error {
destinations, err := client.ListDestinations(api.ListDestinationsRequest{})
if err != nil {
return nil
return err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this

}

user, err := client.GetUser(id)
Expand Down
7 changes: 4 additions & 3 deletions internal/cmd/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ func TestListCmd(t *testing.T) {
assert.Check(t, srv.Run(ctx))
}()

c := apiClient(srv.Addrs.HTTPS.String(), "0000000001.adminadminadminadmin1234", true)
httpTransport := httpTransportForHostConfig(&ClientHostConfig{SkipTLSVerify: true})
c := apiClient(srv.Addrs.HTTPS.String(), "0000000001.adminadminadminadmin1234", httpTransport)

_, err = c.CreateDestination(&api.CreateDestinationRequest{
UniqueID: "space",
Expand Down Expand Up @@ -77,7 +78,7 @@ func TestListCmd(t *testing.T) {
t.Run("with no grants", func(t *testing.T) {
user := userMap["[email protected]"]
err := writeConfig(&ClientConfig{
Version: "0.3",
Version: clientConfigVersion,
Hosts: []ClientHostConfig{
{
PolymorphicID: uid.NewIdentityPolymorphicID(user.ID),
Expand All @@ -103,7 +104,7 @@ func TestListCmd(t *testing.T) {
t.Run("with many grants", func(t *testing.T) {
user := userMap["[email protected]"]
err := writeConfig(&ClientConfig{
Version: "0.3",
Version: clientConfigVersion,
Hosts: []ClientHostConfig{
{
PolymorphicID: uid.NewIdentityPolymorphicID(user.ID),
Expand Down
Loading