From 49c4e4a0ec35d2f9c59d6ad767a90112a413d125 Mon Sep 17 00:00:00 2001 From: Iryna Shustava Date: Wed, 4 Mar 2020 22:37:29 -0800 Subject: [PATCH] Support cloud auto-join --- go.mod | 1 + subcommand/get-consul-client-ca/command.go | 100 +++++++++-- .../get-consul-client-ca/command_test.go | 155 +++++++++++++++++- 3 files changed, 238 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index c6a3d7b81d..5880edc473 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/hashicorp/consul v1.7.1 github.com/hashicorp/consul/api v1.4.0 github.com/hashicorp/consul/sdk v0.4.0 + github.com/hashicorp/go-discover v0.0.0-20191202160150-7ec2cfbda7a2 github.com/hashicorp/go-hclog v0.12.0 github.com/hashicorp/go-multierror v1.0.0 github.com/hashicorp/golang-lru v0.5.3 // indirect diff --git a/subcommand/get-consul-client-ca/command.go b/subcommand/get-consul-client-ca/command.go index 3ce2b60e22..449d115b76 100644 --- a/subcommand/get-consul-client-ca/command.go +++ b/subcommand/get-consul-client-ca/command.go @@ -5,11 +5,15 @@ import ( "fmt" "io/ioutil" "os" + "strings" "sync" "time" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/command/flags" + "github.com/hashicorp/consul/lib" + "github.com/hashicorp/go-discover" + discoverk8s "github.com/hashicorp/go-discover/provider/k8s" "github.com/hashicorp/go-hclog" "github.com/mitchellh/cli" ) @@ -20,21 +24,24 @@ type Command struct { flags *flag.FlagSet flagOutputFile string - flagHttpAddr string + flagServerAddr string flagCAFile string flagTLSServerName string flagLogLevel string once sync.Once help string + + providers map[string]discover.Provider } func (c *Command) init() { c.flags = flag.NewFlagSet("", flag.ContinueOnError) c.flags.StringVar(&c.flagOutputFile, "output-file", "", "The path to the file where to put the Consul client's CA certificate.") - c.flags.StringVar(&c.flagHttpAddr, "http-addr", "", - "The HTTP address of the Consul server. This can also be provided via the CONSUL_HTTP_ADDR environment variable.") + c.flags.StringVar(&c.flagServerAddr, "server-addr", "", + "The URL, IP or the Cloud auto-join string pointing to the Consul server. The server must be running with TLS enabled."+ + "The default HTTPS port 8501 will be used for all connections.") c.flags.StringVar(&c.flagCAFile, "ca-file", "", "The path to the CA file to use when making requests to the Consul server. This can also be provided via the CONSUL_CACERT environment variable") c.flags.StringVar(&c.flagTLSServerName, "tls-server-name", "", @@ -61,13 +68,6 @@ func (c *Command) Run(args []string) int { return 1 } - // create Consul client - consulClient, err := c.consulClient() - if err != nil { - c.UI.Error(fmt.Sprintf("Error initializing Consul client: %s", err)) - return 1 - } - // create a logger level := hclog.LevelFromString(c.flagLogLevel) if level == hclog.NoLevel { @@ -79,6 +79,13 @@ func (c *Command) Run(args []string) int { Output: os.Stderr, }) + // create Consul client + consulClient, err := c.consulClient(logger) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing Consul client: %s", err)) + return 1 + } + // Get the active CA root from Consul // Wait until it gets a successful response var activeRoot string @@ -90,7 +97,7 @@ func (c *Command) Run(args []string) int { continue } - activeRoot, err = c.getActiveRoot(caRoots) + activeRoot, err = getActiveRoot(caRoots) if err != nil { logger.Info("Could not get an active root", "err", err) time.Sleep(1 * time.Second) @@ -104,14 +111,56 @@ func (c *Command) Run(args []string) int { return 1 } + c.UI.Info(fmt.Sprintf("Successfully written Consul client CA to: %s", c.flagOutputFile)) return 0 } -func (c *Command) consulClient() (*api.Client, error) { +func (c *Command) consulClient(logger hclog.Logger) (*api.Client, error) { cfg := api.DefaultConfig() - if c.flagHttpAddr != "" { - cfg.Address = c.flagHttpAddr + + // First, check if the server address is a cloud auto-join string. + // If it is, discover server addresses through the cloud provider. + if strings.Contains(c.flagServerAddr, "provider=") { + disco, err := c.newDiscover() + if err != nil { + return nil, err + } + logger.Debug("using cloud auto-join with", c.flagServerAddr) + servers, err := disco.Addrs(c.flagServerAddr, logger.StandardLogger(&hclog.StandardLoggerOptions{ + InferLevels: true, + })) + if err != nil { + return nil, err + } + + // check if we discovered any servers + if len(servers) == 0 { + return nil, fmt.Errorf("could not discover any Consul servers with %q", c.flagServerAddr) + } + + logger.Debug("discovered servers", strings.Join(servers, " ")) + + // Pick the first server from the list, + // ignoring the port since we need to use HTTP API + firstServer := strings.SplitN(servers[0], ":", 2)[0] + cfg.Address = fmt.Sprintf("%s:8501", firstServer) + cfg.Scheme = "https" + } else { + // check if the server URL is missing a port + host := strings.TrimPrefix(c.flagServerAddr, "https://") + host = strings.TrimPrefix(c.flagServerAddr, "http://") + parts := strings.SplitN(host, ":", 2) + + // Use the default HTTPS port if port is missing. + // Otherwise, use the address the user has provided. + if len(parts) == 1 { + cfg.Address = fmt.Sprintf("%s:8501", c.flagServerAddr) + cfg.Scheme = "https" + } else { + cfg.Address = c.flagServerAddr + } } + if c.flagCAFile != "" { cfg.TLSConfig.CAFile = c.flagCAFile } @@ -122,7 +171,28 @@ func (c *Command) consulClient() (*api.Client, error) { return api.NewClient(cfg) } -func (c *Command) getActiveRoot(roots *api.CARootList) (string, error) { +// newDiscover initializes the new Discover object +// set up with all predefined providers, as well as +// the k8s provider. +func (c *Command) newDiscover() (*discover.Discover, error) { + if c.providers == nil { + c.providers = make(map[string]discover.Provider) + } + + for k, v := range discover.Providers { + c.providers[k] = v + } + c.providers["k8s"] = &discoverk8s.Provider{} + + return discover.New( + discover.WithUserAgent(lib.UserAgent()), + discover.WithProviders(c.providers), + ) +} + +// getActiveRoot returns the currently active root +// from the roots list, otherwise returns error. +func getActiveRoot(roots *api.CARootList) (string, error) { if roots == nil { return "", fmt.Errorf("ca roots is nil") } diff --git a/subcommand/get-consul-client-ca/command_test.go b/subcommand/get-consul-client-ca/command_test.go index f52de08ee0..790d6681c8 100644 --- a/subcommand/get-consul-client-ca/command_test.go +++ b/subcommand/get-consul-client-ca/command_test.go @@ -2,8 +2,13 @@ package getconsulclientca import ( "crypto" + "crypto/x509" "fmt" + "github.com/hashicorp/go-discover" "io/ioutil" + "log" + "net" + "os" "testing" "time" @@ -54,7 +59,7 @@ func TestRun(t *testing.T) { // run the command exitCode := cmd.Run([]string{ - "-http-addr", a.HTTPAddr, + "-server-addr", a.HTTPAddr, "-output-file", outputFile.Name(), }) require.Equal(t, 0, exitCode) @@ -97,7 +102,7 @@ func TestRun_ConsulServerAvailableLater(t *testing.T) { exitCode := -1 go func() { exitCode = cmd.Run([]string{ - "-http-addr", fmt.Sprintf("http://127.0.0.1:%d", randomPorts[1]), + "-server-addr", fmt.Sprintf("http://127.0.0.1:%d", randomPorts[1]), "-output-file", outputFile.Name(), }) require.Equal(t, 0, exitCode) @@ -193,7 +198,7 @@ func TestRun_GetsOnlyActiveRoot(t *testing.T) { }) exitCode := cmd.Run([]string{ - "-http-addr", a.HTTPAddr, + "-server-addr", a.HTTPAddr, "-output-file", outputFile.Name(), }) require.Equal(t, 0, exitCode) @@ -219,6 +224,80 @@ func TestRun_GetsOnlyActiveRoot(t *testing.T) { require.Equal(t, expectedCARoot, string(actualCARoot)) } +// Test that when using cloud auto-join +// it uses the provider to get the address of the server +func TestRun_WithProvider(t *testing.T) { + t.Parallel() + outputFile, err := ioutil.TempFile("", "ca") + require.NoError(t, err) + + ui := cli.NewMockUi() + provider := &fakeProvider{} + cmd := Command{ + UI: ui, + providers: map[string]discover.Provider{"fake": provider}, + } + + caFile, certFile, keyFile, cleanup := generateServerCerts(t) + defer cleanup() + + randomPorts := freeport.MustTake(5) + // start the test server + a, err := testutil.NewTestServerConfigT(t, func(c *testutil.TestServerConfig) { + c.Connect = map[string]interface{}{ + "enabled": true, + } + c.CAFile = caFile + c.CertFile = certFile + c.KeyFile = keyFile + c.Ports = &testutil.TestPortConfig{ + DNS: randomPorts[0], + HTTP: randomPorts[1], + HTTPS: 8501, + SerfLan: randomPorts[2], + SerfWan: randomPorts[3], + Server: randomPorts[4], + } + }) + require.NoError(t, err) + defer a.Stop() + + // run the command + exitCode := cmd.Run([]string{ + "-server-addr", "provider=fake", + "-tls-server-name", "localhost", + "-output-file", outputFile.Name(), + "-ca-file", caFile, + }) + require.Equal(t, 0, exitCode, ui.ErrorWriter.String()) + + // check that the provider has been called + require.Equal(t, 1, provider.addrsNumCalls, "provider's Addrs method was not called") + + client, err := api.NewClient(&api.Config{ + Address: a.HTTPSAddr, + Scheme: "https", + TLSConfig: api.TLSConfig{ + CAFile: caFile, + }, + }) + require.NoError(t, err) + + // get the actual root ca cert from consul + roots, _, err := client.Agent().ConnectCARoots(nil) + require.NoError(t, err) + require.NotNil(t, roots) + require.NotNil(t, roots.Roots) + require.Len(t, roots.Roots, 1) + require.True(t, roots.Roots[0].Active) + expectedCARoot := roots.Roots[0].RootCertPEM + + // read the file contents + actualCARoot, err := ioutil.ReadFile(outputFile.Name()) + require.NoError(t, err) + require.Equal(t, expectedCARoot, string(actualCARoot)) +} + // generateCA generates Consul CA // and returns cert and key as pem strings. func generateCA(t *testing.T) (caPem, keyPem string) { @@ -237,3 +316,73 @@ func generateCA(t *testing.T) (caPem, keyPem string) { return } + +// generateServerCerts generates Consul CA +// and a server certificate and saves them to temp files. +// It returns file names in this order: +// CA certificate, server certificate, and server key. +// Note that it's the responsibility of the caller to +// remove the temporary files created by this function. +func generateServerCerts(t *testing.T) (string, string, string, func()) { + require := require.New(t) + + caFile, err := ioutil.TempFile("", "ca") + require.NoError(err) + + certFile, err := ioutil.TempFile("", "cert") + require.NoError(err) + + certKeyFile, err := ioutil.TempFile("", "key") + require.NoError(err) + + // Generate CA + sn, err := tlsutil.GenerateSerialNumber() + require.NoError(err) + + s, _, err := tlsutil.GeneratePrivateKey() + require.NoError(err) + + constraints := []string{"consul", "localhost"} + ca, err := tlsutil.GenerateCA(s, sn, 1, constraints) + require.NoError(err) + + // Generate Server Cert + name := fmt.Sprintf("server.%s.%s", "dc1", "consul") + DNSNames := []string{name, "localhost"} + IPAddresses := []net.IP{net.ParseIP("127.0.0.1")} + extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth} + + sn, err = tlsutil.GenerateSerialNumber() + require.NoError(err) + + pub, priv, err := tlsutil.GenerateCert(s, ca, sn, name, 1, DNSNames, IPAddresses, extKeyUsage) + require.NoError(err) + + // Write certs and key to files + _, err = caFile.WriteString(ca) + require.NoError(err) + _, err = certFile.WriteString(pub) + require.NoError(err) + _, err = certKeyFile.WriteString(priv) + require.NoError(err) + + cleanupFunc := func() { + os.Remove(caFile.Name()) + os.Remove(certFile.Name()) + os.Remove(certKeyFile.Name()) + } + return caFile.Name(), certFile.Name(), certKeyFile.Name(), cleanupFunc +} + +type fakeProvider struct { + addrsNumCalls int +} + +func (p *fakeProvider) Addrs(args map[string]string, l *log.Logger) ([]string, error) { + p.addrsNumCalls++ + return []string{"127.0.0.1"}, nil +} + +func (p *fakeProvider) Help() string { + return "fake-provider help" +}