diff --git a/BUILD.bazel b/BUILD.bazel index 2d9fb78a2a95..a00e9e60a953 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -118,6 +118,7 @@ load("@bazel_gazelle//:def.bzl", "gazelle") # gazelle:exclude pkg/util/timeutil/pgdate/field_string.go # gazelle:exclude pkg/util/timeutil/pgdate/parsemode_string.go # gazelle:exclude pkg/workload/schemachange/optype_string.go +# gazelle:exclude pkg/geo/wkt/wkt_generated.go # gazelle:exclude pkg/sql/schemachanger/scop/backfill_visitor_generated.go # gazelle:exclude pkg/sql/schemachanger/scop/mutation_visitor_generated.go # gazelle:exclude pkg/sql/schemachanger/scop/validation_visitor_generated.go diff --git a/DEPS.bzl b/DEPS.bzl index 4050acef18c7..eb808cc2c4e0 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -3240,8 +3240,8 @@ def go_deps(): name = "com_github_twpayne_go_geom", build_file_proto_mode = "disable_global", importpath = "github.com/twpayne/go-geom", - sum = "h1:yh2fcro1FLk9uTYi3OSXxtI3JRzaghtsNgaku2ASZbE=", - version = "v1.3.7-0.20210224233516-acd1d64d533a", + sum = "h1:SRMQNnhXCCgFBGAYFnM8iOSMYcOlOwkaTP3pwRCcuOY=", + version = "v1.3.7-0.20210228220813-9d9885b50d3e", ) go_repository( name = "com_github_twpayne_go_kml", diff --git a/go.mod b/go.mod index daaede2e2cdb..c95f0cfc22bc 100644 --- a/go.mod +++ b/go.mod @@ -139,7 +139,7 @@ require ( github.com/spf13/cobra v0.0.5 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.6.1 - github.com/twpayne/go-geom v1.3.7-0.20210224233516-acd1d64d533a + github.com/twpayne/go-geom v1.3.7-0.20210228220813-9d9885b50d3e github.com/wadey/gocovmerge v0.0.0-20160331181800-b5bfa59ec0ad github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c github.com/zabawaba99/go-gitignore v0.0.0-20200117185801-39e6bddfb292 diff --git a/go.sum b/go.sum index 8e00b6dd08bb..1f7de340d5d6 100644 --- a/go.sum +++ b/go.sum @@ -964,8 +964,8 @@ github.com/tinylib/msgp v1.1.1/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDW github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod 
h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= -github.com/twpayne/go-geom v1.3.7-0.20210224233516-acd1d64d533a h1:yh2fcro1FLk9uTYi3OSXxtI3JRzaghtsNgaku2ASZbE= -github.com/twpayne/go-geom v1.3.7-0.20210224233516-acd1d64d533a/go.mod h1:XTyWHR6+l9TUYONbbK4ImUTYbWDCu2ySSPrZmmiA0Pg= +github.com/twpayne/go-geom v1.3.7-0.20210228220813-9d9885b50d3e h1:SRMQNnhXCCgFBGAYFnM8iOSMYcOlOwkaTP3pwRCcuOY= +github.com/twpayne/go-geom v1.3.7-0.20210228220813-9d9885b50d3e/go.mod h1:XTyWHR6+l9TUYONbbK4ImUTYbWDCu2ySSPrZmmiA0Pg= github.com/twpayne/go-kml v1.5.1 h1:RI0JKh/VzdK/d+ZxdJzt8Ar921KMYPfg9qkw7vsbAGw= github.com/twpayne/go-kml v1.5.1/go.mod h1:kz8jAiIz6FIdU2Zjce9qGlVtgFYES9vt7BTPBHf5jl4= github.com/twpayne/go-polyline v1.0.0/go.mod h1:ICh24bcLYBX8CknfvNPKqoTbe+eg+MX1NPyJmSBo7pU= diff --git a/pkg/base/config.go b/pkg/base/config.go index f37ef4e1665f..9970f26f9e7d 100644 --- a/pkg/base/config.go +++ b/pkg/base/config.go @@ -167,10 +167,6 @@ type Config struct { // SSLCertsDir is the path to the certificate/key directory. SSLCertsDir string - // InitToken is a shared initialization token for generating TLS certificates - // across multiple nodes. - InitToken string - // User running this process. It could be the user under which // the server is running or the user passed in client calls. 
User security.SQLUsername diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index cecb5c12a3d6..318f825c6700 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -189,6 +189,7 @@ go_library( "@com_github_cockroachdb_cockroach_go//crdb", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//oserror", + "@com_github_cockroachdb_logtags//:logtags", "@com_github_cockroachdb_pebble//:pebble", "@com_github_cockroachdb_pebble//tool", "@com_github_cockroachdb_redact//:redact", diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index bec94116adc9..d25ecc7ddce3 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -250,6 +250,7 @@ func init() { cockroachCmd.AddCommand( startCmd, startSingleNodeCmd, + connectCmd, initCmd, certCmd, // TODO(bilal): Uncomment this when the connect command does something useful. diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index 8688dd2dc88a..e1431bae0ea4 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -1408,6 +1408,7 @@ func TestFlagUsage(t *testing.T) { Available Commands: start start a node in a multi-node cluster start-single-node start a single-node cluster + connect auto-build TLS certificates for use with the start command init initialize a cluster cert create ca, node, and client certs sql open a sql shell diff --git a/pkg/cli/cliflags/flags.go b/pkg/cli/cliflags/flags.go index e99d051a1f46..fa6402e9dce3 100644 --- a/pkg/cli/cliflags/flags.go +++ b/pkg/cli/cliflags/flags.go @@ -717,8 +717,33 @@ Disable use of "external" IO, such as to S3, GCS, or the file system (nodelocal) } InitToken = FlagInfo{ - Name: "init-token", - Description: `Shared token for initialization of node TLS certificates`, + Name: "init-token", + Description: `Shared token for initialization of node TLS certificates. + +This flag is optional for the 'start' command. When omitted, the 'start' +command expects the operator to prepare TLS certificates beforehand using +the 'cert' command. 
+ +This flag must be combined with --num-expected-initial-nodes.`, + } + + NumExpectedInitialNodes = FlagInfo{ + Name: "num-expected-initial-nodes", + Description: `Number of expected nodes during TLS certificate creation, +including the node where the connect command is run. + +This flag must be combined with --init-token.`, + } + + SingleNode = FlagInfo{ + Name: "single-node", + Description: `Prepare the certificates for a subsequent 'start-single-node' +command. The 'connect' command only runs cursory checks on the network +configuration and does not wait for peers to auto-negotiate a common +set of credentials. + +The --single-node flag is exclusive with the --num-expected-initial-nodes and --init-token +flags.`, } CertsDir = FlagInfo{ diff --git a/pkg/cli/connect.go b/pkg/cli/connect.go index 31ae7c4aae28..e6e35823a9d6 100644 --- a/pkg/cli/connect.go +++ b/pkg/cli/connect.go @@ -13,9 +13,14 @@ package cli import ( "context" "fmt" - "net" + "os" + "github.com/cockroachdb/cockroach/pkg/cli/cliflags" + "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" "github.com/spf13/cobra" ) @@ -23,7 +28,7 @@ import ( // certificates in the specified certs-dir for use with start. var connectCmd = &cobra.Command{ Use: "connect --certs-dir= --init-token= --join=,,...,", - Short: "build TLS certificates for use with the start command", + Short: "auto-build TLS certificates for use with the start command", Long: ` Connects to other nodes and negotiates an initialization bundle for use with secure inter-node connections. @@ -34,18 +39,115 @@ secure inter-node connections. // runConnect connects to other nodes and negotiates an initialization bundle // for use with secure inter-node connections. 
-func runConnect(cmd *cobra.Command, args []string) error { +func runConnect(cmd *cobra.Command, args []string) (retErr error) { + if err := validateConnectFlags(cmd, true /* requireExplicitFlags */); err != nil { + return err + } + + // If the node cert already exists, skip all the complexity of setting up + // servers, etc. + cl := security.MakeCertsLocator(baseCfg.SSLCertsDir) + if exists, err := cl.HasNodeCert(); err != nil { + return err + } else if exists { + return errors.Newf("node certificate already exists in %s", baseCfg.SSLCertsDir) + } + + // Ensure that log files are populated when the process terminates. + defer log.Flush() + peers := []string(serverCfg.JoinList) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", startCtx.serverListenAddr, serverListenPort)) + ctx = logtags.AddTag(ctx, "connect", nil) + + log.Infof(ctx, "validating the command line network arguments") + + // Ensure that the default hostnames / ports are filled in for the + // various address fields in baseCfg. + if err := baseCfg.ValidateAddrs(ctx); err != nil { + return err + } + + log.Ops.Infof(ctx, "starting the initial network listeners") + + // We are creating listeners so that if the host part of the listen + // address means "all interfaces", the Listen call will resolve this + // into a concrete network address. We need all separate listeners + // because the certs will want to use the advertised addresses. 
+ rpcLn, err := server.ListenAndUpdateAddrs(ctx, &baseCfg.Addr, &baseCfg.AdvertiseAddr, "rpc") + if err != nil { + return err + } + log.Ops.Infof(ctx, "started rpc listener at: %s", rpcLn.Addr()) + defer func() { + _ = rpcLn.Close() + }() + httpLn, err := server.ListenAndUpdateAddrs(ctx, &baseCfg.HTTPAddr, &baseCfg.HTTPAdvertiseAddr, "http") if err != nil { return err } + log.Ops.Infof(ctx, "started http listener at: %s", httpLn.Addr()) + if baseCfg.SplitListenSQL { + sqlLn, err := server.ListenAndUpdateAddrs(ctx, &baseCfg.SQLAddr, &baseCfg.SQLAdvertiseAddr, "sql") + if err != nil { + return err + } + log.Ops.Infof(ctx, "started sql listener at: %s", sqlLn.Addr()) + _ = sqlLn.Close() + } + // Note: we want the http listener to remain open while we open the + // SQL listener above to detect port conflict in the configuration + // properly. + _ = httpLn.Close() + defer func() { - _ = listener.Close() + if retErr == nil { + fmt.Println("server certificate generation complete.\n\n" + + "cert files generated in: " + os.ExpandEnv(baseCfg.SSLCertsDir) + "\n\n" + + "Do not forget to generate a client certificate for the 'root' user!\n" + + "This must be done manually, preferably from a different unix user account\n" + + "than the one running the server. Example command:\n\n" + + " " + os.Args[0] + " cert create-client root --ca-key=\n") + } }() - return server.InitHandshake(ctx, baseCfg, baseCfg.InitToken, len(peers), peers, baseCfg.SSLCertsDir, listener) + reporter := func(format string, args ...interface{}) { + fmt.Printf(format+"\n", args...) 
+ } + + return server.InitHandshake(ctx, reporter, baseCfg, startCtx.initToken, startCtx.numExpectedNodes, peers, baseCfg.SSLCertsDir, rpcLn) +} + +func validateConnectFlags(cmd *cobra.Command, requireExplicitFlags bool) error { + if requireExplicitFlags { + f := flagSetForCmd(cmd) + if !(f.Lookup(cliflags.SingleNode.Name).Changed || + (f.Lookup(cliflags.NumExpectedInitialNodes.Name).Changed && f.Lookup(cliflags.InitToken.Name).Changed)) { + return errors.Newf("either --%s must be passed, or both --%s and --%s", + cliflags.SingleNode.Name, cliflags.NumExpectedInitialNodes.Name, cliflags.InitToken.Name) + } + if f.Lookup(cliflags.SingleNode.Name).Changed && + (f.Lookup(cliflags.NumExpectedInitialNodes.Name).Changed || f.Lookup(cliflags.InitToken.Name).Changed) { + return errors.Newf("--%s cannot be specified together with --%s or --%s", + cliflags.SingleNode.Name, cliflags.NumExpectedInitialNodes.Name, cliflags.InitToken.Name) + } + } + + if startCtx.genCertsForSingleNode { + startCtx.numExpectedNodes = 1 + startCtx.initToken = "start-single-node" + return nil + } + + if startCtx.numExpectedNodes < 1 { + return errors.Newf("flag --%s must be set to a value greater than or equal to 1", + cliflags.NumExpectedInitialNodes.Name) + } + if startCtx.initToken == "" { + return errors.Newf("flag --%s must be set to a non-empty string", + cliflags.InitToken.Name) + } + return nil } diff --git a/pkg/cli/context.go b/pkg/cli/context.go index a4066d23d0d8..35295812a91a 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -413,6 +413,11 @@ var startCtx struct { serverCertPrincipalMap []string serverListenAddr string + // The TLS auto-handshake parameters. + initToken string + numExpectedNodes int + genCertsForSingleNode bool + // if specified, this forces the HTTP listen addr to localhost // and disables TLS on the HTTP listener. 
unencryptedLocalhostHTTP bool @@ -448,6 +453,9 @@ func setStartContextDefaults() { startCtx.serverSSLCertsDir = base.DefaultCertsDirectory startCtx.serverCertPrincipalMap = nil startCtx.serverListenAddr = "" + startCtx.initToken = "" + startCtx.numExpectedNodes = 0 + startCtx.genCertsForSingleNode = false startCtx.unencryptedLocalhostHTTP = false startCtx.tempDir = "" startCtx.externalIODir = "" diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index 74de25689556..2c7b0b0f71c8 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -255,7 +255,9 @@ func init() { // Add a pre-run command for `start` and `start-single-node`, as well as the // multi-tenancy related commands that start long-running servers. - for _, cmd := range serverCmds { + // Also for `connect` which does not really start a server but uses + // all the networking flags. + for _, cmd := range append(serverCmds, connectCmd) { AddPersistentPreRunE(cmd, func(cmd *cobra.Command, _ []string) error { // Finalize the configuration of network settings. return extraServerFlagInit(cmd) @@ -334,21 +336,44 @@ func init() { // avoid printing some messages to standard output in that case. _, startCtx.inBackground = envutil.EnvString(backgroundEnvVar, 1) - for _, cmd := range StartCmds { + // Flags common to the start commands and the connect command. + for _, cmd := range append(StartCmds, connectCmd) { f := cmd.Flags() - // Server flags. 
varFlag(f, addrSetter{&startCtx.serverListenAddr, &serverListenPort}, cliflags.ListenAddr) varFlag(f, addrSetter{&serverAdvertiseAddr, &serverAdvertisePort}, cliflags.AdvertiseAddr) varFlag(f, addrSetter{&serverSQLAddr, &serverSQLPort}, cliflags.ListenSQLAddr) varFlag(f, addrSetter{&serverSQLAdvertiseAddr, &serverSQLAdvertisePort}, cliflags.SQLAdvertiseAddr) varFlag(f, addrSetter{&serverHTTPAddr, &serverHTTPPort}, cliflags.ListenHTTPAddr) - stringFlag(f, &serverSocketDir, cliflags.SocketDir) - boolFlag(f, &startCtx.unencryptedLocalhostHTTP, cliflags.UnencryptedLocalhostHTTP) - // The following flag is planned to become non-experimental in 21.1. - boolFlag(f, &serverCfg.AcceptSQLWithoutTLS, cliflags.AcceptSQLWithoutTLS) - _ = f.MarkHidden(cliflags.AcceptSQLWithoutTLS.Name) + // Certificates directory. Use a server-specific flag and value to ignore environment + // variables, but share the same default. + stringFlag(f, &startCtx.serverSSLCertsDir, cliflags.ServerCertsDir) + + // Cluster joining flags. We need to enable this both for 'start' + // and 'start-single-node' although the latter does not support + // --join, because it delegates its logic to that of 'start', and + // 'start' will check that the flag is properly defined. + varFlag(f, &serverCfg.JoinList, cliflags.Join) + boolFlag(f, &serverCfg.JoinPreferSRVRecords, cliflags.JoinPreferSRVRecords) + + // The initialization token and expected peers. For 'start' commands this is optional. + stringFlag(f, &startCtx.initToken, cliflags.InitToken) + intFlag(f, &startCtx.numExpectedNodes, cliflags.NumExpectedInitialNodes) + boolFlag(f, &startCtx.genCertsForSingleNode, cliflags.SingleNode) + + if cmd == startSingleNodeCmd { + // Even though all server flags are supported for + // 'start-single-node', we intend that command to be used by + // beginners / developers running on a single machine. To + // enhance the UX, we hide the flags since they are not directly + // relevant when running a single node. 
+ _ = f.MarkHidden(cliflags.Join.Name) + _ = f.MarkHidden(cliflags.JoinPreferSRVRecords.Name) + _ = f.MarkHidden(cliflags.AdvertiseAddr.Name) + _ = f.MarkHidden(cliflags.SQLAdvertiseAddr.Name) + _ = f.MarkHidden(cliflags.InitToken.Name) + } // Backward-compatibility flags. @@ -372,6 +397,20 @@ func init() { varFlag(f, aliasStrVar{&serverHTTPPort}, cliflags.ListenHTTPPort) _ = f.MarkHidden(cliflags.ListenHTTPPort.Name) + } + + // Flags common to the start commands only. + for _, cmd := range StartCmds { + f := cmd.Flags() + + // Server flags. + stringFlag(f, &serverSocketDir, cliflags.SocketDir) + boolFlag(f, &startCtx.unencryptedLocalhostHTTP, cliflags.UnencryptedLocalhostHTTP) + + // The following flag is planned to become non-experimental in 21.1. + boolFlag(f, &serverCfg.AcceptSQLWithoutTLS, cliflags.AcceptSQLWithoutTLS) + _ = f.MarkHidden(cliflags.AcceptSQLWithoutTLS.Name) + // More server flags. varFlag(f, &localityAdvertiseHosts, cliflags.LocalityAdvertiseAddr) @@ -400,19 +439,10 @@ func init() { boolFlag(f, &serverCfg.ExternalIODirConfig.DisableOutbound, cliflags.ExternalIODisabled) boolFlag(f, &serverCfg.ExternalIODirConfig.DisableImplicitCredentials, cliflags.ExternalIODisableImplicitCredentials) - // Certificates directory. Use a server-specific flag and value to ignore environment - // variables, but share the same default. - stringFlag(f, &startCtx.serverSSLCertsDir, cliflags.ServerCertsDir) - // Certificate principal map. stringSliceFlag(f, &startCtx.serverCertPrincipalMap, cliflags.CertPrincipalMap) - // Cluster joining flags. We need to enable this both for 'start' - // and 'start-single-node' although the latter does not support - // --join, because it delegates its logic to that of 'start', and - // 'start' will check that the flag is properly defined. - varFlag(f, &serverCfg.JoinList, cliflags.Join) - boolFlag(f, &serverCfg.JoinPreferSRVRecords, cliflags.JoinPreferSRVRecords) + // Cluster name verification. 
varFlag(f, clusterNameSetter{&baseCfg.ClusterName}, cliflags.ClusterName) boolFlag(f, &baseCfg.DisableClusterNameVerification, cliflags.DisableClusterNameVerification) if cmd == startSingleNodeCmd { @@ -421,13 +451,10 @@ func init() { // beginners / developers running on a single machine. To // enhance the UX, we hide the flags since they are not directly // relevant when running a single node. - _ = f.MarkHidden(cliflags.Join.Name) _ = f.MarkHidden(cliflags.ClusterName.Name) _ = f.MarkHidden(cliflags.DisableClusterNameVerification.Name) _ = f.MarkHidden(cliflags.MaxOffset.Name) _ = f.MarkHidden(cliflags.LocalityAdvertiseAddr.Name) - _ = f.MarkHidden(cliflags.AdvertiseAddr.Name) - _ = f.MarkHidden(cliflags.SQLAdvertiseAddr.Name) } // Engine flags. @@ -473,15 +500,6 @@ func init() { stringSliceFlag(f, &cliCtx.certPrincipalMap, cliflags.CertPrincipalMap) } - // Flags for the connect command. - { - f := connectCmd.Flags() - stringFlag(f, &baseCfg.SSLCertsDir, cliflags.CertsDir) - stringFlag(f, &baseCfg.InitToken, cliflags.InitToken) - varFlag(f, addrSetter{&startCtx.serverListenAddr, &serverListenPort}, cliflags.ListenAddr) - varFlag(f, &serverCfg.JoinList, cliflags.Join) - } - for _, cmd := range []*cobra.Command{ createCACertCmd, createClientCACertCmd, diff --git a/pkg/cli/interactive_tests/test_connect.tcl b/pkg/cli/interactive_tests/test_connect.tcl new file mode 100644 index 000000000000..7752e8137df5 --- /dev/null +++ b/pkg/cli/interactive_tests/test_connect.tcl @@ -0,0 +1,88 @@ +#! /usr/bin/env expect -f + +source [file join [file dirname $argv0] common.tcl] + +spawn /bin/bash +set shell1_spawn_id $spawn_id +send "PS1=':''/# '\r" +eexpect ":/# " + +set ::env(COCKROACH_INSECURE) "false" + +system "hostname >hostname.txt" + +start_test "Check that the connect command can generate single-node credentials" +# Run connect. We are careful to preserve the generated files into the logs sub-directory +# so that the artifacts remain for investigation if the command fails. 
+# The reason why we do not use --certs-dir=logs directly is that the log directory +# makes its contents world-readable, and crdb asserts that cert / key files +# are not world-readable. +send "$argv connect --single-node --listen-addr=`cat hostname.txt` --http-addr=`cat hostname.txt` --certs-dir=certs/sn; cp -a certs logs/\r" +eexpect "generating cert bundle" +eexpect "cert files generated" +eexpect ":/# " +end_test + +start_test "Check that we can start a secure server with that" +system "$argv start-single-node --listen-addr=`cat hostname.txt` --certs-dir=certs/sn --pid-file=server_pid -s=path=logs/db --background >>logs/expect-cmd.log 2>&1" +end_test + +# NB: we will be able to remove the manual generation of root certs +# some time in the future. +system "$argv cert create-client root --ca-key=certs/sn/ca-client.key --certs-dir=certs/sn" + +start_test "Check we can connect a SQL client with that" +system "$argv sql --certs-dir=certs/sn --host=`cat hostname.txt` -e 'select 1'" +end_test + +# Stop the server we started above. +stop_server $argv + +spawn /bin/bash +set shell2_spawn_id $spawn_id +send "PS1=':''/# '\r" +eexpect ":/# " + +system "mkdir -p logs/n1 logs/n2" + +start_test "Check that the connect command can generate certs for two nodes." 
+set spawn_id $shell1_spawn_id +send "$argv connect --num-expected-initial-nodes 2 --init-token=abc --listen-addr=`cat hostname.txt`:26257 --http-addr=`cat hostname.txt`:8080 --join=`cat hostname.txt`:26258 --certs-dir=certs/n1 --log='file-defaults: {dir: logs/n1}\r" +send "sinks: {stderr: {filter: NONE}}'\r" +eexpect "waiting for handshake" + +set spawn_id $shell2_spawn_id +send "$argv connect --num-expected-initial-nodes 2 --init-token=abc --listen-addr=`cat hostname.txt`:26258 --http-addr=`cat hostname.txt`:8081 --join=`cat hostname.txt`:26257 --certs-dir=certs/n2 --log='file-defaults: {dir: logs/n2}\r" +send "sinks: {stderr: {filter: NONE}}'\r" +eexpect "waiting for handshake" +eexpect "trusted peer" +eexpect "cert bundle" +eexpect "cert files generated in: certs/n2" +eexpect ":/# " + +set spawn_id $shell1_spawn_id +eexpect "trusted peer" +eexpect "cert bundle" +eexpect "cert files generated in: certs/n1" +eexpect ":/# " +end_test + +system "cp -a certs logs/" + +# NB: we will be able to remove the manual generation of root certs +# some time in the future. +system "$argv cert create-client root --ca-key=certs/n1/ca-client.key --certs-dir=certs/n1" +system "$argv cert create-client root --ca-key=certs/n2/ca-client.key --certs-dir=certs/n2" + +# TODO(knz): Also test multi-server start once the advertise addresses are populated. +# +# start_test "Check that we can start two servers using the newly minted certs." 
+# send "$argv start --listen-addr=`cat hostname.txt`:26257 --http-addr=`cat hostname.txt`:8080 --join=`cat hostname.txt`:26258 --certs-dir=certs/n1 --store=logs/db1 --vmodule='*=1'\r" +# eexpect "initial startup completed" +# +# set spawn_id $shell2_spawn_id +# send "$argv start --listen-addr=`cat hostname.txt`:26258 --http-addr=`cat hostname.txt`:8081 --join=`cat hostname.txt`:26257 --certs-dir=certs/n2 --store=logs/db2 --vmodule='*=1'\r" +# eexpect "initial startup completed" + +end_test + diff --git a/pkg/geo/geomfn/affine_transforms_test.go b/pkg/geo/geomfn/affine_transforms_test.go index 1de0b688a444..e72a4bcf7dc0 100644 --- a/pkg/geo/geomfn/affine_transforms_test.go +++ b/pkg/geo/geomfn/affine_transforms_test.go @@ -280,10 +280,10 @@ func TestCollectionScaleRelativeToOrigin(t *testing.T) { }, { desc: "scale empty collection", - input: geom.NewGeometryCollection(), + input: geom.NewGeometryCollection().MustSetLayout(geom.XY), factor: geom.NewPointFlat(geom.XY, []float64{2, 2}), origin: geom.NewPointFlat(geom.XY, []float64{1, 1}), - expected: geom.NewGeometryCollection(), + expected: geom.NewGeometryCollection().MustSetLayout(geom.XY), }, } diff --git a/pkg/geo/geomfn/flip_coordinates_test.go b/pkg/geo/geomfn/flip_coordinates_test.go index c30e10f27d81..3e0e691a7731 100644 --- a/pkg/geo/geomfn/flip_coordinates_test.go +++ b/pkg/geo/geomfn/flip_coordinates_test.go @@ -110,8 +110,8 @@ func TestFlipCoordinates(t *testing.T) { }, { desc: "flip coordinates of an empty collection", - input: geom.NewGeometryCollection(), - expected: geom.NewGeometryCollection(), + input: geom.NewGeometryCollection().MustSetLayout(geom.XY), + expected: geom.NewGeometryCollection().MustSetLayout(geom.XY), }, } diff --git a/pkg/geo/wkt/BUILD.bazel b/pkg/geo/wkt/BUILD.bazel index 2674f6d6a216..24697e85b3d0 100644 --- a/pkg/geo/wkt/BUILD.bazel +++ b/pkg/geo/wkt/BUILD.bazel @@ -5,7 +5,7 @@ go_library( srcs = [ "lex.go", "wkt.go", - "wkt_generated.go", + ":wkt-generated", # keep ], 
importpath = "github.com/cockroachdb/cockroach/pkg/geo/wkt", visibility = ["//visibility:public"], @@ -22,3 +22,20 @@ go_test( "@com_github_twpayne_go_geom//:go-geom", ], ) + +# Based on pkg/geo/wkt/generate.sh file +genrule( + name = "wkt-generated", + srcs = [ + "wkt.y", + ], + outs = ["wkt_generated.go"], + cmd = """ + $(location @org_golang_x_tools//cmd/goyacc) -o $(location wkt_generated.go) -p "wkt" $(location wkt.y) + cat $(location wkt_generated.go) | sed -e 's/wktErrorVerbose = false/wktErrorVerbose = true/' > wkt_generated.go.tmp + mv wkt_generated.go.tmp $(location wkt_generated.go) + """, + tools = [ + "@org_golang_x_tools//cmd/goyacc", + ], +) diff --git a/pkg/geo/wkt/generate.sh b/pkg/geo/wkt/generate.sh old mode 100644 new mode 100755 diff --git a/pkg/security/auto_tls_init.go b/pkg/security/auto_tls_init.go index 68a9ca6588e5..e56fec652ba2 100644 --- a/pkg/security/auto_tls_init.go +++ b/pkg/security/auto_tls_init.go @@ -12,6 +12,7 @@ package security import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -23,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" + "github.com/cockroachdb/redact" ) // TODO(aaron-crl): This shared a name and purpose with the value in @@ -53,11 +55,39 @@ func createCertificateSerialNumber() (serialNumber *big.Int, err error) { return } +// LoggerFn is the type we use to inject logging functions into the +// security package to avoid circular dependencies. 
+type LoggerFn = func(ctx context.Context, format string, args ...interface{}) + +func describeCert(cert *x509.Certificate) redact.RedactableString { + var buf redact.StringBuilder + buf.SafeString("{\n") + buf.Printf(" SN: %s,\n", cert.SerialNumber) + buf.Printf(" CA: %v,\n", cert.IsCA) + buf.Printf(" Issuer: %q,\n", cert.Issuer) + buf.Printf(" Subject: %q,\n", cert.Subject) + buf.Printf(" NotBefore: %s,\n", cert.NotBefore) + buf.Printf(" NotAfter: %s", cert.NotAfter) + buf.Printf(" (Validity: %s),\n", cert.NotAfter.Sub(timeutil.Now())) + if !cert.IsCA { + buf.Printf(" DNS: %v,\n", cert.DNSNames) + buf.Printf(" IP: %v\n", cert.IPAddresses) + } + buf.SafeString("}") + return buf.RedactableString() +} + +const ( + crlOrg = "Cockroach Labs" + crlIssuerOU = "automatic cert generator" + crlC = "US" +) + // CreateCACertAndKey will create a CA with a validity beginning // now() and expiring after `lifespan`. This is a utility function to help // with cluster auto certificate generation. func CreateCACertAndKey( - lifespan time.Duration, service string, + ctx context.Context, loggerFn LoggerFn, lifespan time.Duration, service string, ) (certPEM []byte, keyPEM []byte, err error) { notBefore := timeutil.Now().Add(-notBeforeMargin) notAfter := timeutil.Now().Add(lifespan) @@ -71,10 +101,15 @@ func CreateCACertAndKey( // Create short lived initial CA template. 
ca := &x509.Certificate{ SerialNumber: serialNumber, + Issuer: pkix.Name{ + Organization: []string{crlOrg}, + OrganizationalUnit: []string{crlIssuerOU}, + Country: []string{crlC}, + }, Subject: pkix.Name{ - Organization: []string{"Cockroach Labs"}, + Organization: []string{crlOrg}, OrganizationalUnit: []string{service}, - Country: []string{"US"}, + Country: []string{crlC}, }, NotBefore: notBefore, NotAfter: notAfter, @@ -83,6 +118,9 @@ func CreateCACertAndKey( BasicConstraintsValid: true, MaxPathLenZero: true, } + if loggerFn != nil { + loggerFn(ctx, "creating CA cert from template: %s", describeCert(ca)) + } // Create private and public key for CA. caPrivKey, err := rsa.GenerateKey(rand.Reader, defaultKeySize) @@ -104,6 +142,9 @@ func CreateCACertAndKey( return nil, nil, err } + if loggerFn != nil { + loggerFn(ctx, "signing CA cert") + } // Create CA certificate then PEM encode it. caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) if err != nil { @@ -128,7 +169,14 @@ func CreateCACertAndKey( // CreateServiceCertAndKey creates a cert/key pair signed by the provided CA. // This is a utility function to help with cluster auto certificate generation. func CreateServiceCertAndKey( - lifespan time.Duration, service string, hostnames []string, caCertPEM []byte, caKeyPEM []byte, + ctx context.Context, + loggerFn LoggerFn, + lifespan time.Duration, + commonName, service string, + hostnames []string, + caCertPEM []byte, + caKeyPEM []byte, + serviceCertIsAlsoValidAsClient bool, ) (certPEM []byte, keyPEM []byte, err error) { notBefore := timeutil.Now().Add(-notBeforeMargin) notAfter := timeutil.Now().Add(lifespan) @@ -169,10 +217,16 @@ func CreateServiceCertAndKey( // pkg/security/x509.go until we can consolidate them. 
serviceCert := &x509.Certificate{ SerialNumber: serialNumber, + Issuer: pkix.Name{ + Organization: []string{crlOrg}, + OrganizationalUnit: []string{crlIssuerOU}, + Country: []string{crlC}, + }, Subject: pkix.Name{ - Organization: []string{"Cockroach Labs"}, + Organization: []string{crlOrg}, OrganizationalUnit: []string{service}, - Country: []string{"US"}, + Country: []string{crlC}, + CommonName: commonName, }, NotBefore: notBefore, NotAfter: notAfter, @@ -180,6 +234,10 @@ func CreateServiceCertAndKey( KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, } + if serviceCertIsAlsoValidAsClient { + serviceCert.ExtKeyUsage = append(serviceCert.ExtKeyUsage, x509.ExtKeyUsageClientAuth) + } + // Attempt to parse hostname as IP, if successful add it as an IP // otherwise presume it is a DNS name. // TODO(aaron-crl): Pass these values via config object. @@ -192,11 +250,18 @@ func CreateServiceCertAndKey( } } + if loggerFn != nil { + loggerFn(ctx, "creating service cert from template: %s", describeCert(serviceCert)) + } + servicePrivKey, err := rsa.GenerateKey(rand.Reader, defaultKeySize) if err != nil { return nil, nil, err } + if loggerFn != nil { + loggerFn(ctx, "signing service cert") + } serviceCertBytes, err := x509.CreateCertificate(rand.Reader, serviceCert, caCert, &servicePrivKey.PublicKey, caKey) if err != nil { return nil, nil, err diff --git a/pkg/security/auto_tls_init_test.go b/pkg/security/auto_tls_init_test.go index b66c67904e55..5cd4e1d5405b 100644 --- a/pkg/security/auto_tls_init_test.go +++ b/pkg/security/auto_tls_init_test.go @@ -11,6 +11,7 @@ package security_test import ( + "context" "testing" "time" @@ -22,7 +23,8 @@ import ( // TODO(aaron-crl): [tests] write unit tests func TestDummyCreateCACertAndKey(t *testing.T) { defer leaktest.AfterTest(t)() - _, _, err := security.CreateCACertAndKey(time.Hour, "test CA cert generation") + _, _, err := security.CreateCACertAndKey(context.Background(), nil, /* loggerFn */ + time.Hour, "test CA 
cert generation") if err != nil { t.Fatalf("expected err=nil, got: %s", err) } @@ -32,17 +34,21 @@ func TestDummyCreateCACertAndKey(t *testing.T) { // TODO(aaron-crl): [tests] write unit tests func TestDummyCreateServiceCertAndKey(t *testing.T) { defer leaktest.AfterTest(t)() - caCert, caKey, err := security.CreateCACertAndKey(time.Hour, "test CA cert generation") + caCert, caKey, err := security.CreateCACertAndKey(context.Background(), nil, /* loggerFn */ + time.Hour, "test CA cert generation") if err != nil { t.Fatalf("expected err=nil, got: %s", err) } _, _, err = security.CreateServiceCertAndKey( + context.Background(), nil, /* loggerFn */ time.Minute, + "dummy-common-name", "test Service cert generation", []string{"localhost", "127.0.0.1"}, caCert, caKey, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { t.Fatalf("expected err=nil, got: %s", err) diff --git a/pkg/security/certificate_manager.go b/pkg/security/certificate_manager.go index 3985f8f0375c..d0d804d55389 100644 --- a/pkg/security/certificate_manager.go +++ b/pkg/security/certificate_manager.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/sysutil" "github.com/cockroachdb/errors" + "github.com/cockroachdb/errors/oserror" ) var ( @@ -262,6 +263,12 @@ func (cl CertsLocator) CACertPath() string { return filepath.Join(cl.certsDir, CACertFilename()) } +// EnsureCertsDirectory ensures that the certs directory exists by +// creating it if it does not exist yet. +func (cl CertsLocator) EnsureCertsDirectory() error { + return os.MkdirAll(cl.certsDir, 0700) +} + // CACertFilename returns the expected file name for the CA certificate. func CACertFilename() string { return "ca" + certExtension } @@ -314,6 +321,18 @@ func (cl CertsLocator) NodeCertPath() string { return filepath.Join(cl.certsDir, NodeCertFilename()) } +// HasNodeCert returns true iff the node certificate file already exists. 
+func (cl CertsLocator) HasNodeCert() (bool, error) { + _, err := os.Stat(cl.NodeCertPath()) + if err != nil { + if oserror.IsNotExist(err) { + return false, nil + } + return false, err + } + return true, nil +} + // NodeCertFilename returns the expected file name for the node certificate. func NodeCertFilename() string { return "node" + certExtension diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 005a68d47e39..2ac143451c10 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -339,6 +339,7 @@ go_test( "//pkg/util/uuid", "@com_github_cockroachdb_datadriven//:datadriven", "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", "@com_github_dustin_go_humanize//:go-humanize", "@com_github_gogo_protobuf//jsonpb", "@com_github_gogo_protobuf//proto", diff --git a/pkg/server/auto_tls_init.go b/pkg/server/auto_tls_init.go index 748a359d4f43..4c3a7d4ccad8 100644 --- a/pkg/server/auto_tls_init.go +++ b/pkg/server/auto_tls_init.go @@ -17,6 +17,7 @@ package server import ( + "context" "encoding/pem" "io/ioutil" "os" @@ -24,8 +25,11 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/security" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/errors" "github.com/cockroachdb/errors/oserror" + "github.com/cockroachdb/logtags" ) // TODO(aaron-crl): This is an exact copy from `pkg/cli/cert.go` and should @@ -36,11 +40,11 @@ const defaultCALifetime = 10 * 366 * 24 * time.Hour // ten years const defaultCertLifetime = 5 * 366 * 24 * time.Hour // five years // Service Name Strings for autogenerated certificates. 
-const serviceNameInterNode = "InterNode Service" -const serviceNameUserAuth = "User Auth Service" -const serviceNameSQL = "SQL Service" -const serviceNameRPC = "RPC Service" -const serviceNameUI = "UI Service" +const serviceNameInterNode = "cockroach-node" +const serviceNameUserAuth = "cockroach-client" +const serviceNameSQL = "cockroach-sql" +const serviceNameRPC = "cockroach-rpc" +const serviceNameUI = "cockroach-http" // CertificateBundle manages the collection of certificates used by a // CockroachDB node. @@ -90,14 +94,20 @@ func (sb *ServiceCertificateBundle) loadCACertAndKey(certPath string, keyPath st // LoadUserAuthCACertAndKey loads host certificate and key from disk or fails with error. func (sb *ServiceCertificateBundle) loadOrCreateUserAuthCACertAndKey( - caCertPath string, caKeyPath string, initLifespan time.Duration, serviceName string, + ctx context.Context, + caCertPath string, + caKeyPath string, + initLifespan time.Duration, + serviceName string, ) (err error) { + log.Ops.Infof(ctx, "attempting to load CA cert: %s", caCertPath) // Attempt to load cert into ServiceCertificateBundle. sb.CACertificate, err = loadCertificateFile(caCertPath) if err != nil { if oserror.IsNotExist(err) { + log.Ops.Infof(ctx, "not found; auto-generating") // Certificate not found, attempt to create both cert and key now. - err = sb.createServiceCA(caCertPath, caKeyPath, initLifespan, serviceName) + err = sb.createServiceCA(ctx, caCertPath, caKeyPath, initLifespan, serviceName) if err != nil { return err } @@ -110,12 +120,16 @@ func (sb *ServiceCertificateBundle) loadOrCreateUserAuthCACertAndKey( return err } + log.Ops.Infof(ctx, "found; loading CA key: %s", caKeyPath) // Load the key only if it exists. sb.CAKey, err = loadKeyFile(caKeyPath) - if !oserror.IsNotExist(err) { - // An error returned but it was not that the file didn't exist; - // this is an error. 
- return err + if err != nil { + if !oserror.IsNotExist(err) { + // An error returned but it was not that the file didn't exist; + // this is an error. + return err + } + log.Ops.Infof(ctx, "CA key not found") } return nil @@ -133,20 +147,26 @@ func (sb *ServiceCertificateBundle) loadOrCreateUserAuthCACertAndKey( // It will persist these to disk and store them // in the ServiceCertificateBundle. func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( + ctx context.Context, serviceCertPath string, serviceKeyPath string, caCertPath string, caKeyPath string, serviceCertLifespan time.Duration, caCertLifespan time.Duration, + commonName string, serviceName string, hostnames []string, + serviceCertIsAlsoValidAsClient bool, ) error { - var err error + ctx = logtags.AddTag(ctx, "service", serviceName) + var err error + log.Ops.Infof(ctx, "attempting to load service cert: %s", serviceCertPath) // Check if the service cert and key already exist, if it does return early. sb.HostCertificate, err = loadCertificateFile(serviceCertPath) if err == nil { + log.Ops.Infof(ctx, "found; loading service key: %s", serviceKeyPath) // Cert file exists, now load key. sb.HostKey, err = loadKeyFile(serviceKeyPath) if err != nil { @@ -160,14 +180,20 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( return errors.Wrap(err, "something went wrong loading service key") } // Both certificate and key should be successfully loaded. + log.Ops.Infof(ctx, "service cert is ready") return nil } + // TODO(aaron-crl, knz): err != nil is not handled here. + + log.Ops.Infof(ctx, "not found; will attempt auto-creation") - // Niether service cert or key exist, attempt to load CA. + log.Ops.Infof(ctx, "attempting to load CA cert: %s", caCertPath) + // Neither service cert or key exist, attempt to load CA. sb.CACertificate, err = loadCertificateFile(caCertPath) if err == nil { // CA cert has been successfully loaded, attempt to load // CA key. 
+ log.Ops.Infof(ctx, "found; loading CA key: %s", caKeyPath) sb.CAKey, err = loadKeyFile(caKeyPath) if err != nil { return errors.Wrapf( @@ -175,23 +201,27 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( ) } } else if oserror.IsNotExist(err) { + log.Ops.Infof(ctx, "not found; CA cert does not exist, auto-creating") // CA cert does not yet exist, create it and its key. - err = sb.createServiceCA(caCertPath, caKeyPath, caCertLifespan, serviceName) - if err != nil { + if err := sb.createServiceCA(ctx, caCertPath, caKeyPath, caCertLifespan, serviceName); err != nil { return errors.Wrap( err, "failed to create Service CA", ) } } + // TODO(aaron-crl, knz): missing `else` case here. // CA cert and key should now be loaded, create service cert and key. - //var hostCert, hostKey []byte sb.HostCertificate, sb.HostKey, err = security.CreateServiceCertAndKey( + ctx, + log.Ops.Infof, serviceCertLifespan, + commonName, serviceName, hostnames, sb.CACertificate, sb.CAKey, + serviceCertIsAlsoValidAsClient, ) if err != nil { return errors.Wrap( @@ -199,13 +229,13 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( ) } - err = writeCertificateFile(serviceCertPath, sb.HostCertificate, false) - if err != nil { + log.Ops.Infof(ctx, "writing service cert: %s", serviceCertPath) + if err := writeCertificateFile(serviceCertPath, sb.HostCertificate, false); err != nil { return err } - err = writeKeyFile(serviceKeyPath, sb.HostKey, false) - if err != nil { + log.Ops.Infof(ctx, "writing service key: %s", serviceKeyPath) + if err := writeKeyFile(serviceKeyPath, sb.HostKey, false); err != nil { return err } @@ -215,24 +245,31 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( // createServiceCA builds CA cert and key and populates them to // ServiceCertificateBundle. 
func (sb *ServiceCertificateBundle) createServiceCA( - caCertPath string, caKeyPath string, initLifespan time.Duration, serviceName string, -) (err error) { - sb.CACertificate, sb.CAKey, err = security.CreateCACertAndKey(initLifespan, serviceName) + ctx context.Context, + caCertPath string, + caKeyPath string, + initLifespan time.Duration, + serviceName string, +) error { + ctx = logtags.AddTag(ctx, "auto-create-ca", nil) + + var err error + sb.CACertificate, sb.CAKey, err = security.CreateCACertAndKey(ctx, log.Ops.Infof, initLifespan, serviceName) if err != nil { - return + return err } - err = writeCertificateFile(caCertPath, sb.CACertificate, false) - if err != nil { - return + log.Ops.Infof(ctx, "writing CA cert: %s", caCertPath) + if err := writeCertificateFile(caCertPath, sb.CACertificate, false); err != nil { + return err } - err = writeKeyFile(caKeyPath, sb.CAKey, false) - if err != nil { - return + log.Ops.Infof(ctx, "writing CA key: %s", caKeyPath) + if err := writeKeyFile(caKeyPath, sb.CAKey, false); err != nil { + return err } - return + return nil } // Simple wrapper to make it easier to store certs somewhere else later. @@ -292,95 +329,113 @@ func writeKeyFile(keyFilePath string, keyPEMBytes []byte, overwrite bool) error // cluster. It uses or generates an InterNode CA to produce any missing // unmanaged certificates. It does this base on the logic in: // https://github.com/cockroachdb/cockroach/pull/51991 -// N.B.: This function fast fails if an InterNodeHost cert/key pair are present +// N.B.: This function fast fails if an inter-node cert/key pair are present // as this should _never_ happen. -func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { +func (b *CertificateBundle) InitializeFromConfig(ctx context.Context, c base.Config) error { cl := security.MakeCertsLocator(c.SSLCertsDir) // First check to see if host cert is already present // if it is, we should fail to initialize. 
- if _, err := os.Stat(cl.NodeCertPath()); err == nil { - return errors.New( - "interNodeHost certificate already present") - } else if !oserror.IsNotExist(err) { - return errors.Wrap( - err, "interNodeHost certificate access issue") + if exists, err := cl.HasNodeCert(); err != nil { + return err + } else if exists { + return errors.New("inter-node certificate already present") + } + + rpcAddrs := extractHosts(c.Addr, c.AdvertiseAddr) + sqlAddrs := rpcAddrs + if c.SplitListenSQL { + sqlAddrs = extractHosts(c.SQLAddr, c.SQLAdvertiseAddr) + } + httpAddrs := extractHosts(c.HTTPAddr, c.HTTPAdvertiseAddr) + + // Create the target directory if it does not exist yet. + if err := cl.EnsureCertsDirectory(); err != nil { + return err } // Start by loading or creating the InterNode certificates. - err := b.InterNode.loadOrCreateServiceCertificates( + if err := b.InterNode.loadOrCreateServiceCertificates( + ctx, cl.NodeCertPath(), cl.NodeKeyPath(), cl.CACertPath(), cl.CAKeyPath(), defaultCertLifetime, defaultCALifetime, + security.NodeUser, serviceNameInterNode, - []string{c.Addr, c.AdvertiseAddr}, - ) - if err != nil { + rpcAddrs, + true, /* serviceCertIsAlsoValidAsClient */ + ); err != nil { return errors.Wrap(err, "failed to load or create InterNode certificates") } // Initialize User auth certificates. - err = b.UserAuth.loadOrCreateUserAuthCACertAndKey( + if err := b.UserAuth.loadOrCreateUserAuthCACertAndKey( + ctx, cl.ClientCACertPath(), cl.ClientCAKeyPath(), defaultCALifetime, serviceNameUserAuth, - ) - if err != nil { + ); err != nil { return errors.Wrap(err, "failed to load or create User auth certificate(s)") } // Initialize SQLService Certs. 
- err = b.SQLService.loadOrCreateServiceCertificates( + if err := b.SQLService.loadOrCreateServiceCertificates( + ctx, cl.SQLServiceCertPath(), cl.SQLServiceKeyPath(), cl.SQLServiceCACertPath(), cl.SQLServiceCAKeyPath(), defaultCertLifetime, defaultCALifetime, + security.NodeUser, serviceNameSQL, // TODO(aaron-crl): Add RPC variable to config or SplitSQLAddr. - []string{c.SQLAddr, c.SQLAdvertiseAddr}, - ) - if err != nil { + sqlAddrs, + false, /* serviceCertIsAlsoValidAsClient */ + ); err != nil { return errors.Wrap(err, "failed to load or create SQL service certificate(s)") } // Initialize RPCService Certs. - err = b.RPCService.loadOrCreateServiceCertificates( + if err := b.RPCService.loadOrCreateServiceCertificates( + ctx, cl.RPCServiceCertPath(), cl.RPCServiceKeyPath(), cl.RPCServiceCACertPath(), cl.RPCServiceCAKeyPath(), defaultCertLifetime, defaultCALifetime, + security.NodeUser, serviceNameRPC, // TODO(aaron-crl): Add RPC variable to config. - []string{c.SQLAddr, c.SQLAdvertiseAddr}, - ) - if err != nil { + rpcAddrs, + false, /* serviceCertIsAlsoValidAsClient */ + ); err != nil { return errors.Wrap(err, "failed to load or create RPC service certificate(s)") } // Initialize AdminUIService Certs. 
- err = b.AdminUIService.loadOrCreateServiceCertificates( + if err := b.AdminUIService.loadOrCreateServiceCertificates( + ctx, cl.UICertPath(), cl.UIKeyPath(), cl.UICACertPath(), cl.UICAKeyPath(), defaultCertLifetime, defaultCALifetime, + httpAddrs[0], serviceNameUI, - []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, - ) - if err != nil { + httpAddrs, + false, /* serviceCertIsAlsoValidAsClient */ + ); err != nil { return errors.Wrap(err, "failed to load or create Admin UI service certificate(s)") } @@ -388,20 +443,45 @@ func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { return nil } +func extractHosts(addrs ...string) []string { + res := make([]string, 0, len(addrs)) + + for _, addr := range addrs { + hostname, _, err := netutil.SplitHostPort(addr, "0") + if err != nil { + panic(err) + } + found := false + for _, h := range res { + if h == hostname { + found = true + break + } + } + if !found { + res = append(res, hostname) + } + } + return res +} + // InitializeNodeFromBundle uses the contents of the CertificateBundle and // details from the config object to write certs to disk and generate any // missing host-specific certificates and keys // It is assumed that a node receiving this has not has TLS initialized. If -// a interNodeHost certificate is found, this function will error. -func (b *CertificateBundle) InitializeNodeFromBundle(c base.Config) error { +// an inter-node certificate is found, this function will error. +func (b *CertificateBundle) InitializeNodeFromBundle(ctx context.Context, c base.Config) error { cl := security.MakeCertsLocator(c.SSLCertsDir) // First check to see if host cert is already present // if it is, we should fail to initialize. 
- if _, err := os.Stat(cl.NodeCertPath()); err == nil { - return errors.New("interNodeHost certificate already present") - } else if !oserror.IsNotExist(err) { - // Something else went wrong accessing the path + if exists, err := cl.HasNodeCert(); err != nil { + return err + } else if exists { + return errors.New("inter-node certificate already present") + } + + if err := cl.EnsureCertsDirectory(); err != nil { return err } @@ -409,39 +489,33 @@ func (b *CertificateBundle) InitializeNodeFromBundle(c base.Config) error { // and return an error. // Attempt to write InterNodeHostCA to disk first. - err := b.InterNode.writeCAOrFail(cl.CACertPath(), cl.CAKeyPath()) - if err != nil { + if err := b.InterNode.writeCAOrFail(cl.CACertPath(), cl.CAKeyPath()); err != nil { return errors.Wrap(err, "failed to write InterNodeCA to disk") } // Attempt to write ClientCA to disk. - err = b.InterNode.writeCAOrFail(cl.ClientCACertPath(), cl.ClientCAKeyPath()) - if err != nil { + if err := b.InterNode.writeCAOrFail(cl.ClientCACertPath(), cl.ClientCAKeyPath()); err != nil { return errors.Wrap(err, "failed to write ClientCA to disk") } // Attempt to write SQLServiceCA to disk. - err = b.InterNode.writeCAOrFail(cl.SQLServiceCACertPath(), cl.SQLServiceCAKeyPath()) - if err != nil { + if err := b.InterNode.writeCAOrFail(cl.SQLServiceCACertPath(), cl.SQLServiceCAKeyPath()); err != nil { return errors.Wrap(err, "failed to write SQLServiceCA to disk") } // Attempt to write RPCServiceCA to disk. - err = b.InterNode.writeCAOrFail(cl.RPCServiceCACertPath(), cl.RPCServiceCAKeyPath()) - if err != nil { + if err := b.InterNode.writeCAOrFail(cl.RPCServiceCACertPath(), cl.RPCServiceCAKeyPath()); err != nil { return errors.Wrap(err, "failed to write RPCServiceCA to disk") } // Attempt to write AdminUIServiceCA to disk. 
- err = b.InterNode.writeCAOrFail(cl.UICACertPath(), cl.UICAKeyPath()) - if err != nil { + if err := b.InterNode.writeCAOrFail(cl.UICACertPath(), cl.UICAKeyPath()); err != nil { return errors.Wrap(err, "failed to write AdminUIServiceCA to disk") } // Once CAs are written call the same InitFromConfig function to create // host certificates. - err = b.InitializeFromConfig(c) - if err != nil { + if err := b.InitializeFromConfig(ctx, c); err != nil { return errors.Wrap( err, "failed to initialize host certs after writing CAs to disk") @@ -536,7 +610,7 @@ func collectLocalCABundle(c base.Config) (CertificateBundle, error) { // manually after rotation errors are corrected without negatively impacting // any interface. All existing interfaces will again receive a new // certificate/key pair. -func rotateGeneratedCerts(c base.Config) error { +func rotateGeneratedCerts(ctx context.Context, c base.Config) error { cl := security.MakeCertsLocator(c.SSLCertsDir) // Fail fast if we can't load the CAs. @@ -546,14 +620,24 @@ func rotateGeneratedCerts(c base.Config) error { err, "failed to load local CAs for certificate rotation") } + rpcAddrs := extractHosts(c.Addr, c.AdvertiseAddr) + sqlAddrs := rpcAddrs + if c.SplitListenSQL { + sqlAddrs = extractHosts(c.SQLAddr, c.SQLAdvertiseAddr) + } + httpAddrs := extractHosts(c.HTTPAddr, c.HTTPAdvertiseAddr) + // Rotate InterNode Certs. if b.InterNode.CACertificate != nil { err = b.InterNode.rotateServiceCert( + ctx, cl.NodeCertPath(), cl.NodeKeyPath(), defaultCertLifetime, + security.NodeUser, serviceNameInterNode, - []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + rpcAddrs, + true, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return errors.Wrap(err, "failed to rotate InterNode cert") @@ -565,11 +649,14 @@ func rotateGeneratedCerts(c base.Config) error { // Rotate SQLService Certs. 
if b.SQLService.CACertificate != nil { err = b.SQLService.rotateServiceCert( + ctx, cl.SQLServiceCertPath(), cl.SQLServiceKeyPath(), defaultCertLifetime, + security.NodeUser, serviceNameSQL, - []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + sqlAddrs, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return errors.Wrap(err, "failed to rotate SQLService cert") @@ -579,11 +666,14 @@ func rotateGeneratedCerts(c base.Config) error { // Rotate RPCService Certs. if b.RPCService.CACertificate != nil { err = b.RPCService.rotateServiceCert( + ctx, cl.RPCServiceCertPath(), cl.RPCServiceKeyPath(), defaultCertLifetime, + security.NodeUser, serviceNameRPC, - []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + rpcAddrs, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return errors.Wrap(err, "failed to rotate RPCService cert") @@ -593,11 +683,14 @@ func rotateGeneratedCerts(c base.Config) error { // Rotate AdminUIService Certs. if b.AdminUIService.CACertificate != nil { err = b.AdminUIService.rotateServiceCert( + ctx, cl.UICertPath(), cl.UIKeyPath(), defaultCertLifetime, + httpAddrs[0], serviceNameUI, - []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + httpAddrs, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return errors.Wrap(err, "failed to rotate AdminUIService cert") @@ -611,19 +704,25 @@ func rotateGeneratedCerts(c base.Config) error { // hostnames and path signed by the ca at the supplied paths. It will only // succeed if it is able to generate these and OVERWRITE an exist file. 
func (sb *ServiceCertificateBundle) rotateServiceCert( + ctx context.Context, certPath string, keyPath string, serviceCertLifespan time.Duration, - serviceString string, + commonName, serviceString string, hostnames []string, + serviceCertIsAlsoValidAsClient bool, ) error { // generate certPEM, keyPEM, err := security.CreateServiceCertAndKey( + ctx, + log.Ops.Infof, serviceCertLifespan, + commonName, serviceString, hostnames, sb.CACertificate, sb.CAKey, + serviceCertIsAlsoValidAsClient, ) if err != nil { return errors.Wrapf( diff --git a/pkg/server/auto_tls_init_test.go b/pkg/server/auto_tls_init_test.go index 8643bcab7bfa..1c96ec30a3be 100644 --- a/pkg/server/auto_tls_init_test.go +++ b/pkg/server/auto_tls_init_test.go @@ -12,6 +12,7 @@ package server import ( "bytes" + "context" "io" "io/ioutil" "os" @@ -38,13 +39,13 @@ func TestInitializeFromConfig(t *testing.T) { SSLCertsDir: tempDir, } - err = certBundle.InitializeFromConfig(cfg) + err = certBundle.InitializeFromConfig(context.Background(), cfg) if err != nil { t.Fatalf("expected err=nil, got: %q", err) } // Verify certs written to disk match certs in bundles. 
- bundleFromDisk, err := loadAllCertsFromDisk(cfg) + bundleFromDisk, err := loadAllCertsFromDisk(context.Background(), cfg) if err != nil { t.Fatalf("failed loading certs from disk, got: %q", err) } @@ -61,7 +62,7 @@ func TestInitializeFromConfig(t *testing.T) { } -func loadAllCertsFromDisk(cfg base.Config) (CertificateBundle, error) { +func loadAllCertsFromDisk(ctx context.Context, cfg base.Config) (CertificateBundle, error) { cl := security.MakeCertsLocator(cfg.SSLCertsDir) bundleFromDisk, err := collectLocalCABundle(cfg) if err != nil { @@ -69,7 +70,8 @@ func loadAllCertsFromDisk(cfg base.Config) (CertificateBundle, error) { } err = bundleFromDisk.InterNode.loadOrCreateServiceCertificates( - cl.NodeCertPath(), cl.NodeKeyPath(), "", "", 0, 0, "", []string{}, + ctx, cl.NodeCertPath(), cl.NodeKeyPath(), "", "", 0, 0, security.NodeUser, "", []string{}, + true, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return bundleFromDisk, err @@ -77,24 +79,27 @@ func loadAllCertsFromDisk(cfg base.Config) (CertificateBundle, error) { // TODO(aaron-crl): Figure out how to handle client auth case. 
//bundleFromDisk.UserAuth.loadOrCreateServiceCertificates( - // cl.ClientCertPath(), cl.ClientKeyPath(), "", "", 0, "", []string{}, + // ctx, cl.ClientCertPath(), cl.ClientKeyPath(), "", "", 0, 0, security.NodeUser, "", []string{}, //) err = bundleFromDisk.SQLService.loadOrCreateServiceCertificates( - cl.SQLServiceCertPath(), cl.SQLServiceKeyPath(), "", "", 0, 0, "", []string{}, + ctx, cl.SQLServiceCertPath(), cl.SQLServiceKeyPath(), "", "", 0, 0, security.NodeUser, "", []string{}, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return bundleFromDisk, err } err = bundleFromDisk.RPCService.loadOrCreateServiceCertificates( - cl.RPCServiceCertPath(), cl.RPCServiceKeyPath(), "", "", 0, 0, "", []string{}, + ctx, cl.RPCServiceCertPath(), cl.RPCServiceKeyPath(), "", "", 0, 0, security.NodeUser, "", []string{}, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return bundleFromDisk, err } err = bundleFromDisk.AdminUIService.loadOrCreateServiceCertificates( - cl.UICertPath(), cl.UIKeyPath(), "", "", 0, 0, "", []string{}, + ctx, cl.UICertPath(), cl.UIKeyPath(), "", "", 0, 0, "fakehost", "", []string{}, + false, /* serviceCertIsAlsoValidAsClient */ ) if err != nil { return bundleFromDisk, err @@ -219,7 +224,7 @@ func TestDummyInitializeNodeFromBundle(t *testing.T) { SSLCertsDir: tempDir, } - err = certBundle.InitializeNodeFromBundle(cfg) + err = certBundle.InitializeNodeFromBundle(context.Background(), cfg) if err != nil { t.Fatalf("expected err=nil, got: %s", err) } @@ -233,7 +238,9 @@ func TestDummyCertLoader(t *testing.T) { scb := ServiceCertificateBundle{} _ = scb.loadServiceCertAndKey("", "") _ = scb.loadCACertAndKey("", "") - _ = scb.rotateServiceCert("", "", time.Minute, "", []string{""}) + _ = scb.rotateServiceCert(context.Background(), "", "", time.Minute, "fakehost", "", []string{""}, + false, /* serviceCertIsAlsoValidAsClient */ + ) } // TestNodeCertRotation tests that the rotation function will overwrite the @@ -272,7 +279,9 @@ func 
TestRotationOnUnintializedNode(t *testing.T) { t.Fatal("files added to cfg.SSLCertsDir when they shouldn't have been") } - err = rotateGeneratedCerts(cfg) + // TODO(aaron-crl): Verify that the certs are rotated for the proper + // addresses. + err = rotateGeneratedCerts(context.Background(), cfg) if err != nil { t.Fatalf("expected nil error generating no certs, got: %q", err) } @@ -296,21 +305,24 @@ func TestRotationOnIntializedNode(t *testing.T) { cfg := base.Config{ SSLCertsDir: tempDir, } + ctx := context.Background() // Test in the fully provisioned case. certBundle := CertificateBundle{} - err = certBundle.InitializeFromConfig(cfg) + err = certBundle.InitializeFromConfig(ctx, cfg) if err != nil { t.Fatalf("expected err=nil, got: %q", err) } - err = rotateGeneratedCerts(cfg) + // TODO(aaron-crl): Verify that the certs are rotated for the proper + // addresses. + err = rotateGeneratedCerts(ctx, cfg) if err != nil { t.Fatalf("rotation failed; expected err=nil, got: %q", err) } // Verify that any existing certs have changed on disk for services - diskBundle, err := loadAllCertsFromDisk(cfg) + diskBundle, err := loadAllCertsFromDisk(ctx, cfg) if err != nil { t.Fatalf("failed loading certs from disk, got: %q", err) } @@ -334,9 +346,11 @@ func TestRotationOnPartialIntializedNode(t *testing.T) { cfg := base.Config{ SSLCertsDir: tempDir, } + ctx := context.Background() + // Test in the partially provisioned case (remove the Client and UI CAs). certBundle := CertificateBundle{} - err = certBundle.InitializeFromConfig(cfg) + err = certBundle.InitializeFromConfig(ctx, cfg) if err != nil { t.Fatalf("expected err=nil, got: %q", err) } @@ -355,14 +369,17 @@ func TestRotationOnPartialIntializedNode(t *testing.T) { t.Fatalf("failed to remove test cert: %q", err) } + // TODO(aaron-crl): Verify that the certs are rotated for the proper + // addresses. + // This should rotate all service certs except client and UI. 
- err = rotateGeneratedCerts(cfg) + err = rotateGeneratedCerts(ctx, cfg) if err != nil { t.Fatalf("rotation failed; expected err=nil, got: %q", err) } // Verify that client and UI service host certs are unchanged. - diskBundle, err := loadAllCertsFromDisk(cfg) + diskBundle, err := loadAllCertsFromDisk(ctx, cfg) if err != nil { t.Fatalf("failed loading certs from disk, got: %q", err) } @@ -399,9 +416,11 @@ func TestRotationOnBrokenIntializedNode(t *testing.T) { cfg := base.Config{ SSLCertsDir: tempDir, } + ctx := context.Background() + cl := security.MakeCertsLocator(cfg.SSLCertsDir) certBundle := CertificateBundle{} - err = certBundle.InitializeFromConfig(cfg) + err = certBundle.InitializeFromConfig(ctx, cfg) if err != nil { t.Fatalf("expected err=nil, got: %q", err) } @@ -414,7 +433,7 @@ func TestRotationOnBrokenIntializedNode(t *testing.T) { t.Fatalf("failed to remove test cert: %q", err) } - err = rotateGeneratedCerts(cfg) + err = rotateGeneratedCerts(ctx, cfg) if err == nil { t.Fatalf("rotation succeeded but should have failed with missing leaf certs for SQLService") } diff --git a/pkg/server/init_handshake.go b/pkg/server/init_handshake.go index fdccf9681ecf..5642ee29ba2e 100644 --- a/pkg/server/init_handshake.go +++ b/pkg/server/init_handshake.go @@ -30,6 +30,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" ) const ( @@ -128,13 +129,27 @@ func pemToSignature(caCertPEM []byte) ([]byte, error) { } func createNodeInitTempCertificates( - hostnames []string, lifespan time.Duration, + ctx context.Context, hostnames []string, lifespan time.Duration, ) (certs ServiceCertificateBundle, err error) { - caCert, caKey, err := security.CreateCACertAndKey(lifespan, initServiceName) + log.Ops.Infof(ctx, "creating temporary initial certificates for hosts %+v, duration %s", hostnames, lifespan) + + caCtx := logtags.AddTag(ctx, 
"create-temp-ca", nil) + caCert, caKey, err := security.CreateCACertAndKey(caCtx, log.Ops.Infof, lifespan, initServiceName) if err != nil { return certs, err } - serviceCert, serviceKey, err := security.CreateServiceCertAndKey(lifespan, initServiceName, hostnames, caCert, caKey) + serviceCtx := logtags.AddTag(ctx, "create-temp-service", nil) + serviceCert, serviceKey, err := security.CreateServiceCertAndKey( + serviceCtx, + log.Ops.Infof, + lifespan, + security.NodeUser, + initServiceName, + hostnames, + caCert, + caKey, + false, /* serviceCertIsAlsoValidAsClient */ + ) if err != nil { return certs, err } @@ -150,7 +165,7 @@ func createNodeInitTempCertificates( func sendBadRequestError(ctx context.Context, err error, w http.ResponseWriter) { http.Error(w, "invalid request message", http.StatusBadRequest) - log.Warningf(ctx, "bad request: %s", err) + log.Ops.Warningf(ctx, "bad request: %s", err) } func generateURLForClient(peer string, endpoint string) string { @@ -159,7 +174,6 @@ func generateURLForClient(peer string, endpoint string) string { // tlsInitHandshaker takes in a list of peers type tlsInitHandshaker struct { - ctx context.Context server *http.Server token []byte @@ -173,7 +187,7 @@ type tlsInitHandshaker struct { wg sync.WaitGroup } -func (t *tlsInitHandshaker) init() error { +func (t *tlsInitHandshaker) init(ctx context.Context) error { serverCert, err := tls.X509KeyPair(t.tempCerts.HostCertificate, t.tempCerts.HostKey) if err != nil { return err @@ -189,8 +203,8 @@ func (t *tlsInitHandshaker) init() error { } mux := http.NewServeMux() - mux.HandleFunc(trustInitURL, t.onTrustInit) - mux.HandleFunc(deliverBundleURL, t.onDeliverBundle) + mux.HandleFunc(trustInitURL, enhanceHandlerContextWithHTTPClient(ctx, t.onTrustInit)) + mux.HandleFunc(deliverBundleURL, enhanceHandlerContextWithHTTPClient(ctx, t.onDeliverBundle)) t.server = &http.Server{ Addr: t.listenAddr, @@ -200,20 +214,31 @@ func (t *tlsInitHandshaker) init() error { return nil } -// Handler for 
initial challenge and ack containing the ephemeral node CAs -func (t *tlsInitHandshaker) onTrustInit(res http.ResponseWriter, req *http.Request) { +func enhanceHandlerContextWithHTTPClient( + baseCtx context.Context, fn func(ctx context.Context, w http.ResponseWriter, req *http.Request), +) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + ctx := logtags.AddTag(baseCtx, "peer", req.RemoteAddr) + fn(ctx, w, req) + } +} + +// Handler for initial challenge and ack containing the ephemeral node CAs. +func (t *tlsInitHandshaker) onTrustInit( + ctx context.Context, res http.ResponseWriter, req *http.Request, +) { var challenge nodeHostnameAndCA // TODO(aaron-crl): [Security] Make this more error resilient to size and shape attacks. err := json.NewDecoder(req.Body).Decode(&challenge) if err != nil { - sendBadRequestError(t.ctx, errors.Wrap(err, "error when unmarshalling challenge"), res) + sendBadRequestError(ctx, errors.Wrap(err, "error when unmarshalling challenge"), res) return } defer req.Body.Close() if !challenge.validHMAC(t.token) { - sendBadRequestError(t.ctx, errInvalidHMAC, res) + sendBadRequestError(ctx, errInvalidHMAC, res) // Non-blocking channel send. select { case t.errors <- errInvalidHMAC: @@ -223,6 +248,8 @@ func (t *tlsInitHandshaker) onTrustInit(res http.ResponseWriter, req *http.Reque return } + log.Ops.Infof(ctx, "received valid challenge and CA from: %s", challenge.HostAddress) + t.trustedPeers <- challenge // Acknowledge validation to the client. @@ -238,27 +265,33 @@ func (t *tlsInitHandshaker) onTrustInit(res http.ResponseWriter, req *http.Reque } // Handler to allow peer to deliver internode CA trust bundle. 
-func (t *tlsInitHandshaker) onDeliverBundle(res http.ResponseWriter, req *http.Request) { +func (t *tlsInitHandshaker) onDeliverBundle( + ctx context.Context, res http.ResponseWriter, req *http.Request, +) { bundle := nodeTrustBundle{} err := json.NewDecoder(req.Body).Decode(&bundle) defer req.Body.Close() if err != nil { - sendBadRequestError(t.ctx, errors.Wrap(err, "error when unmarshalling bundle"), res) + sendBadRequestError(ctx, errors.Wrap(err, "error when unmarshalling bundle"), res) return } - if bundle.validHMAC(t.token) { - // Successfully provisioned. - t.finishedInit <- &bundle.Bundle - close(t.finishedInit) + if !bundle.validHMAC(t.token) { + sendBadRequestError(ctx, errors.New("invalid bundle HMAC"), res) + return + } + + log.Ops.Infof(ctx, "received valid cert bundle from trust leader") + // Successfully provisioned. + select { + case t.finishedInit <- &bundle.Bundle: + // Done. + case <-ctx.Done(): + log.Ops.Warningf(ctx, "context canceled while receiving bundle") } } -func (t *tlsInitHandshaker) startServer(listener net.Listener) { - go func() { - // Start the server. - _ = t.server.ServeTLS(listener, "", "") - t.wg.Done() - }() +func (t *tlsInitHandshaker) startServer(listener net.Listener) error { + return t.server.ServeTLS(listener, "", "") } func (t *tlsInitHandshaker) stopServer() { @@ -335,8 +368,9 @@ func (t *tlsInitHandshaker) getPeerCACert( return msg, nil } -func (t *tlsInitHandshaker) runClient(peerHostname string, selfAddress string) { - defer t.wg.Done() +func (t *tlsInitHandshaker) runClient( + ctx context.Context, peerHostname string, selfAddress string, +) { // Sleep for 500ms between attempts. 
ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() @@ -345,13 +379,14 @@ func (t *tlsInitHandshaker) runClient(peerHostname string, selfAddress string) { for { select { - case <-t.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: } peerHostnameAndCa, err := t.getPeerCACert(client, peerHostname, selfAddress) if err != nil { + log.Ops.Warningf(ctx, "peer CA retrieval error: %v", err) // Non-blocking channel send. select { case t.errors <- err: @@ -364,14 +399,14 @@ func (t *tlsInitHandshaker) runClient(peerHostname string, selfAddress string) { } select { case t.trustedPeers <- peerHostnameAndCa: - case <-t.ctx.Done(): + case <-ctx.Done(): } return } } func (t *tlsInitHandshaker) sendBundle( - address string, peerCACert []byte, caBundle nodeTrustBundle, + ctx context.Context, address string, peerCACert []byte, caBundle nodeTrustBundle, ) (err error) { rootCAs, _ := x509.SystemCertPool() rootCAs.AppendCertsFromPEM(peerCACert) @@ -384,12 +419,14 @@ func (t *tlsInitHandshaker) sendBundle( return err } + every := log.Every(time.Second) + ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() var lastError error for { select { - case <-t.ctx.Done(): + case <-ctx.Done(): if lastError != nil { return lastError } @@ -402,6 +439,9 @@ func (t *tlsInitHandshaker) sendBundle( break } lastError = err + if every.ShouldLog() { + log.Ops.Warningf(ctx, "cannot send bundle: %v", err) + } } return nil @@ -409,13 +449,22 @@ func (t *tlsInitHandshaker) sendBundle( func initHandshakeHelper( ctx context.Context, + reporter func(string, ...interface{}), cfg *base.Config, token string, - numExpectedPeers int, + numExpectedNodes int, peers []string, certsDir string, listener net.Listener, ) error { + if len(token) == 0 { + return errors.AssertionFailedf("programming error: token cannot be empty") + } + if numExpectedNodes <= 0 { + return errors.AssertionFailedf("programming error: must expect more than 1 node") + } + numExpectedPeers := numExpectedNodes 
- 1 + addr := listener.Addr() var listenHost string switch netAddr := addr.(type) { @@ -424,13 +473,13 @@ func initHandshakeHelper( default: return errors.New("unsupported listener protocol: only TCP listeners supported") } - tempCerts, err := createNodeInitTempCertificates([]string{listenHost}, defaultInitLifespan) + tempCerts, err := createNodeInitTempCertificates(ctx, []string{listenHost}, defaultInitLifespan) if err != nil { return errors.Wrap(err, "failed to create certificates") } + log.Infof(ctx, "initializing temporary TLS handshake server, listen addr: %s", addr) handshaker := &tlsInitHandshaker{ - ctx: ctx, token: []byte(token), certsDir: certsDir, listenAddr: addr.String(), @@ -439,39 +488,68 @@ func initHandshakeHelper( errors: make(chan error, numExpectedPeers*2), finishedInit: make(chan *CertificateBundle, 1), } - if err := handshaker.init(); err != nil { + if err := handshaker.init(ctx); err != nil { return errors.Wrap(err, "error when initializing tls handshaker") } - // Add to waitGroup for every client (= len(peers)) and server (= 1) goroutine - // instantiated. The calls to wg.Done() are made by the server/client - // goroutines themselves. - handshaker.wg.Add(len(peers) + 1) + // Wait for the server and all the clients to terminate before returning. defer handshaker.wg.Wait() - handshaker.startServer(listener) - defer handshaker.stopServer() + peerCACerts := make(map[string]([]byte)) - for _, peerAddress := range peers { - go handshaker.runClient(peerAddress, addr.String()) - } + if numExpectedPeers > 0 { + handshaker.wg.Add(1) + go func() { + defer handshaker.wg.Done() - // Wait until we have numExpectedPeers peer certificates. - peerCACerts := make(map[string]([]byte)) - for len(peerCACerts) < numExpectedPeers { - select { - case p := <-handshaker.trustedPeers: - peerCACerts[p.HostAddress] = p.CACertificate - case err := <-handshaker.errors: - if errors.Is(err, errInvalidHMAC) { - // Either this peer, or another peer, has the wrong token. 
Fail - // fast. - return errors.New("invalid signature in messages from peers; likely due to token mismatch") + log.Ops.Infof(ctx, "starting handshake server") + defer log.Ops.Infof(ctx, "handshake server stopped") + if err := handshaker.startServer(listener); !errors.Is(err, http.ErrServerClosed) { + log.Ops.Errorf(ctx, "handshake server failed: %v", err) + } + }() + // Terminate the server before exiting. + defer handshaker.stopServer() + + for _, peerAddress := range peers { + handshaker.wg.Add(1) + go func(peerAddress string) { + defer handshaker.wg.Done() + + peerCtx := logtags.AddTag(ctx, "peer", peerAddress) + log.Ops.Infof(peerCtx, "starting handshake client for peer") + handshaker.runClient(peerCtx, peerAddress, addr.String()) + }(peerAddress) + } + + if reporter != nil { + reporter("waiting for handshake for %d peers", numExpectedPeers) + } + + // Wait until we have numExpectedPeers peer certificates. + for len(peerCACerts) < numExpectedPeers { + select { + case p := <-handshaker.trustedPeers: + log.Ops.Infof(ctx, "received CA certificate for peer: %s", p.HostAddress) + if reporter != nil { + reporter("trusted peer: %s", p.HostAddress) + } + peerCACerts[p.HostAddress] = p.CACertificate + + case err := <-handshaker.errors: + if errors.Is(err, errInvalidHMAC) { + // Either this peer, or another peer, has the wrong token. Fail + // fast. 
+ log.Ops.Errorf(ctx, "HMAC error from client when connecting to peer: %v", err) + return errors.New("invalid signature in messages from peers; likely due to token mismatch") + } + log.Ops.Warningf(ctx, "error from client when connecting to peers (retrying): %s", err) + + case <-ctx.Done(): + return errors.New("context canceled before all peers connected") } - log.Errorf(ctx, "error from client when connecting to peers (retrying): %s", err) - case <-ctx.Done(): - return errors.New("context canceled before peers connected") } + log.Ops.Infof(ctx, "received response from all peers; choosing trust leader") } // Order nodes by certificates. @@ -494,10 +572,16 @@ func initHandshakeHelper( // Initialize if this node is the trust leader. If not, wait for trust bundle // to come from another node. if trustLeader { + if reporter != nil { + reporter("generating cert bundle for cluster") + } + log.Ops.Infof(ctx, "we are trust leader; initializing certificate bundle") + leaderCtx := logtags.AddTag(ctx, "trust-leader", nil) + var b CertificateBundle // TODO(bilal): See if we can get rid of the need to store a base.Config // pointer. This is the only place in this method where it is necessary. - if err := b.InitializeFromConfig(*cfg); err != nil { + if err := b.InitializeFromConfig(leaderCtx, *cfg); err != nil { return errors.Wrap(err, "error when creating initialization bundle") } @@ -508,9 +592,16 @@ func initHandshakeHelper( trustBundle := nodeTrustBundle{Bundle: peerInit} trustBundle.signHMAC(handshaker.token) + + if reporter != nil { + reporter("sending cert bundle to peers") + } + // For each peer, use its CA to establish a secure connection and deliver the trust bundle. 
for p := range peerCACerts { - if err := handshaker.sendBundle(p, peerCACerts[p], trustBundle); err != nil { + peerCtx := logtags.AddTag(leaderCtx, "peer", p) + log.Ops.Infof(peerCtx, "delivering bundle to peer") + if err := handshaker.sendBundle(peerCtx, p, peerCACerts[p], trustBundle); err != nil { // TODO(bilal): sendBundle should fail fast instead of retrying (or // waiting for ctx cancellation) if the error returned is due to a // mismatching CA cert than peerCACerts[p]. This would likely mean @@ -522,12 +613,23 @@ func initHandshakeHelper( return nil } + if reporter != nil { + reporter("waiting for cert bundle") + } + log.Ops.Infof(ctx, "we are not trust leader; now waiting for bundle from trust leader") + select { case b := <-handshaker.finishedInit: + if reporter != nil { + reporter("received cert bundle") + } + if b == nil { return errors.New("expected non-nil init bundle to be received from trust leader") } - return b.InitializeNodeFromBundle(*cfg) + log.Ops.Infof(ctx, "received bundle, now initializing node certificate files") + return b.InitializeNodeFromBundle(ctx, *cfg) + case <-ctx.Done(): return errors.New("context canceled before init bundle received from leader") } @@ -538,9 +640,10 @@ func initHandshakeHelper( // negotiates an inter-node CA and puts it in certsDir. func InitHandshake( ctx context.Context, + reporter func(string, ...interface{}), cfg *base.Config, token string, - numExpectedPeers int, + numExpectedNodes int, peers []string, certsDir string, listener net.Listener, @@ -548,6 +651,7 @@ func InitHandshake( // TODO(bilal): Allow defaultInitLifespan to be configurable, possibly through // base.Config. 
return contextutil.RunWithTimeout(ctx, "init handshake", defaultInitLifespan, func(ctx context.Context) error { - return initHandshakeHelper(ctx, cfg, token, numExpectedPeers, peers, certsDir, listener) + ctx = logtags.AddTag(ctx, "init-tls-handshake", nil) + return initHandshakeHelper(ctx, reporter, cfg, token, numExpectedNodes, peers, certsDir, listener) }) } diff --git a/pkg/server/init_handshake_test.go b/pkg/server/init_handshake_test.go index bc06c23d6867..854b7d1e3d2d 100644 --- a/pkg/server/init_handshake_test.go +++ b/pkg/server/init_handshake_test.go @@ -12,7 +12,6 @@ package server import ( "context" - "net" "os" "path" "testing" @@ -20,15 +19,20 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" "github.com/stretchr/testify/require" ) func TestInitHandshake(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) + + skip.UnderShort(t) + timeout := 11 * time.Minute ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -36,20 +40,29 @@ func TestInitHandshake(t *testing.T) { tempDir, del := testutils.TempDir(t) defer del() + ctx1 := logtags.AddTag(ctx, "n", 1) cfg1 := &base.Config{} cfg1.InitDefaults() cfg1.SSLCertsDir = path.Join(tempDir, "temp1") + cfg1.Addr = "127.0.0.1:0" require.NoError(t, os.Mkdir(cfg1.SSLCertsDir, 0755)) + require.NoError(t, cfg1.ValidateAddrs(ctx1)) + ctx2 := logtags.AddTag(ctx, "n", 2) cfg2 := &base.Config{} cfg2.InitDefaults() cfg2.SSLCertsDir = path.Join(tempDir, "temp2") + cfg2.Addr = "127.0.0.1:0" require.NoError(t, os.Mkdir(cfg2.SSLCertsDir, 0755)) + require.NoError(t, cfg2.ValidateAddrs(ctx2)) + ctx3 := logtags.AddTag(ctx, "n", 3) cfg3 := &base.Config{} cfg3.InitDefaults() cfg3.SSLCertsDir = 
path.Join(tempDir, "temp3") + cfg3.Addr = "127.0.0.1:0" require.NoError(t, os.Mkdir(cfg3.SSLCertsDir, 0755)) + require.NoError(t, cfg3.ValidateAddrs(ctx3)) errReturned := make(chan error, 1) // Do a three-node handshake, and ensure no error is returned. The errors @@ -57,35 +70,41 @@ func TestInitHandshake(t *testing.T) { // not be empty. var addr1, addr2, addr3 string - listener1, err := net.Listen("tcp4", "127.0.0.1:0") + listener1, err := ListenAndUpdateAddrs(ctx1, &cfg1.Addr, &cfg1.AdvertiseAddr, "rpc") require.NoError(t, err) defer func() { _ = listener1.Close() }() addr1 = listener1.Addr().String() - listener2, err := net.Listen("tcp4", "127.0.0.1:0") + listener2, err := ListenAndUpdateAddrs(ctx2, &cfg2.Addr, &cfg2.AdvertiseAddr, "rpc") require.NoError(t, err) defer func() { _ = listener2.Close() }() addr2 = listener2.Addr().String() - listener3, err := net.Listen("tcp4", "127.0.0.1:0") + listener3, err := ListenAndUpdateAddrs(ctx3, &cfg3.Addr, &cfg3.AdvertiseAddr, "rpc") require.NoError(t, err) defer func() { _ = listener3.Close() }() addr3 = listener3.Addr().String() + reporter := func(prefix string) func(string, ...interface{}) { + return func(format string, args ...interface{}) { + t.Logf(prefix+": "+format, args...) 
+ } + } + go func() { - errReturned <- InitHandshake(ctx, cfg1, "foobar", 2, []string{addr2, addr3}, cfg1.SSLCertsDir, listener1) + errReturned <- InitHandshake(ctx1, reporter("n1"), cfg1, "foobar", 3, []string{addr2, addr3}, cfg1.SSLCertsDir, listener1) }() go func() { - errReturned <- InitHandshake(ctx, cfg2, "foobar", 2, []string{addr1, addr3}, cfg1.SSLCertsDir, listener2) + errReturned <- InitHandshake(ctx2, reporter("n2"), cfg2, "foobar", 3, []string{addr1, addr3}, cfg2.SSLCertsDir, listener2) }() go func() { - errReturned <- InitHandshake(ctx, cfg3, "foobar", 2, []string{addr1, addr2}, cfg1.SSLCertsDir, listener3) + errReturned <- InitHandshake(ctx3, reporter("n3"), cfg3, "foobar", 3, []string{addr1, addr2}, cfg3.SSLCertsDir, listener3) }() count := 0 @@ -116,64 +135,83 @@ func TestInitHandshake(t *testing.T) { func TestInitHandshakeWrongToken(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) + + skip.UnderShort(t) + // The test deadline needs to be greater than this for this test to pass, // as one of the nodes will have to wait for this context to time out. 
- timeout := 30 * time.Second + timeout := 20 * time.Second ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() tempDir, del := testutils.TempDir(t) defer del() + ctx1 := logtags.AddTag(ctx, "n", 1) cfg1 := &base.Config{} cfg1.InitDefaults() cfg1.SSLCertsDir = path.Join(tempDir, "temp1") + cfg1.Addr = "127.0.0.1:0" require.NoError(t, os.Mkdir(cfg1.SSLCertsDir, 0755)) + require.NoError(t, cfg1.ValidateAddrs(ctx1)) + ctx2 := logtags.AddTag(ctx, "n", 2) cfg2 := &base.Config{} cfg2.InitDefaults() cfg2.SSLCertsDir = path.Join(tempDir, "temp2") + cfg2.Addr = "127.0.0.1:0" require.NoError(t, os.Mkdir(cfg2.SSLCertsDir, 0755)) + require.NoError(t, cfg2.ValidateAddrs(ctx2)) + ctx3 := logtags.AddTag(ctx, "n", 3) cfg3 := &base.Config{} cfg3.InitDefaults() cfg3.SSLCertsDir = path.Join(tempDir, "temp3") + cfg3.Addr = "127.0.0.1:0" require.NoError(t, os.Mkdir(cfg3.SSLCertsDir, 0755)) + require.NoError(t, cfg3.ValidateAddrs(ctx3)) errReturned := make(chan error, 1) - // Do a three-node handshake, with one node having the wrong token. At least - // one of the three errors returned should be non-nil. + // Do a three-node handshake, with one node having the wrong token. Unlike + // TestInitHandshake, at least one of the three errors returned should be + // non-nil.
var addr1, addr2, addr3 string - listener1, err := net.Listen("tcp4", "127.0.0.1:0") + listener1, err := ListenAndUpdateAddrs(ctx1, &cfg1.Addr, &cfg1.AdvertiseAddr, "rpc") require.NoError(t, err) defer func() { _ = listener1.Close() }() addr1 = listener1.Addr().String() - listener2, err := net.Listen("tcp4", "127.0.0.1:0") + listener2, err := ListenAndUpdateAddrs(ctx2, &cfg2.Addr, &cfg2.AdvertiseAddr, "rpc") require.NoError(t, err) defer func() { _ = listener2.Close() }() addr2 = listener2.Addr().String() - listener3, err := net.Listen("tcp4", "127.0.0.1:0") + listener3, err := ListenAndUpdateAddrs(ctx3, &cfg3.Addr, &cfg3.AdvertiseAddr, "rpc") require.NoError(t, err) defer func() { _ = listener3.Close() }() addr3 = listener3.Addr().String() + reporter := func(prefix string) func(string, ...interface{}) { + return func(format string, args ...interface{}) { + t.Logf(prefix+": "+format, args...) + } + } + go func() { - errReturned <- InitHandshake(ctx, cfg1, "foobar", 2, []string{addr2, addr3}, cfg1.SSLCertsDir, listener1) + errReturned <- InitHandshake(ctx1, reporter("n1"), cfg1, "foobar", 3, []string{addr2, addr3}, cfg1.SSLCertsDir, listener1) }() go func() { - errReturned <- InitHandshake(ctx, cfg2, "foobarbaz", 2, []string{addr1, addr3}, cfg1.SSLCertsDir, listener2) + errReturned <- InitHandshake(ctx2, reporter("n2"), cfg2, "foobarbaz", 3, []string{addr1, addr3}, cfg2.SSLCertsDir, listener2) }() go func() { - errReturned <- InitHandshake(ctx, cfg3, "foobar", 2, []string{addr1, addr2}, cfg1.SSLCertsDir, listener3) + errReturned <- InitHandshake(ctx3, reporter("n3"), cfg3, "foobar", 3, []string{addr1, addr2}, cfg3.SSLCertsDir, listener3) }() count := 0 diff --git a/pkg/server/server.go b/pkg/server/server.go index 881ba5f68c33..3e739e31a07e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1868,7 +1868,7 @@ func (s *Server) startListenRPCAndSQL( } if ln == nil { var err error - ln, err = listen(ctx, &s.cfg.Addr, &s.cfg.AdvertiseAddr, rpcChanName) + 
ln, err = ListenAndUpdateAddrs(ctx, &s.cfg.Addr, &s.cfg.AdvertiseAddr, rpcChanName) if err != nil { return nil, nil, err } @@ -1877,7 +1877,7 @@ func (s *Server) startListenRPCAndSQL( var pgL net.Listener if s.cfg.SplitListenSQL { - pgL, err = listen(ctx, &s.cfg.SQLAddr, &s.cfg.SQLAdvertiseAddr, "sql") + pgL, err = ListenAndUpdateAddrs(ctx, &s.cfg.SQLAddr, &s.cfg.SQLAdvertiseAddr, "sql") if err != nil { return nil, nil, err } @@ -1970,7 +1970,7 @@ func (s *Server) startListenRPCAndSQL( func (s *Server) startServeUI( ctx, workersCtx context.Context, connManager netutil.Server, uiTLSConfig *tls.Config, ) error { - httpLn, err := listen(ctx, &s.cfg.HTTPAddr, &s.cfg.HTTPAdvertiseAddr, "http") + httpLn, err := ListenAndUpdateAddrs(ctx, &s.cfg.HTTPAddr, &s.cfg.HTTPAdvertiseAddr, "http") if err != nil { return err } @@ -2484,7 +2484,11 @@ type tcpKeepAliveManager struct { loggedKeepAliveStatus int32 } -func listen( +// ListenAndUpdateAddrs starts a TCP listener on the specified address +// then updates the address and advertised address fields based on the +// actual interface address resolved by the OS during the Listen() +// call. 
+func ListenAndUpdateAddrs( ctx context.Context, addr, advertiseAddr *string, connName string, ) (net.Listener, error) { ln, err := net.Listen("tcp", *addr) diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 6049a98a0778..082bf105583b 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -781,7 +781,7 @@ func StartTenant( SetupIdleMonitor(ctx, args.stopper, baseCfg.IdleExitAfter, connManager) } - pgL, err := listen(ctx, &args.Config.SQLAddr, &args.Config.SQLAdvertiseAddr, "sql") + pgL, err := ListenAndUpdateAddrs(ctx, &args.Config.SQLAddr, &args.Config.SQLAdvertiseAddr, "sql") if err != nil { return nil, "", "", err } @@ -801,7 +801,7 @@ func StartTenant( } } - httpL, err := listen(ctx, &args.Config.HTTPAddr, &args.Config.HTTPAdvertiseAddr, "http") + httpL, err := ListenAndUpdateAddrs(ctx, &args.Config.HTTPAddr, &args.Config.HTTPAdvertiseAddr, "http") if err != nil { return nil, "", "", err } diff --git a/pkg/storage/pebble.go b/pkg/storage/pebble.go index 358313330404..a85e0e1d3557 100644 --- a/pkg/storage/pebble.go +++ b/pkg/storage/pebble.go @@ -1072,8 +1072,17 @@ func (p *Pebble) NewBatch() Batch { // NewReadOnly implements the Engine interface. func (p *Pebble) NewReadOnly() ReadWriter { + // TODO(sumeer): a sync.Pool for pebbleReadOnly would save on allocations + // for the underlying pebbleIterators. return &pebbleReadOnly{ parent: p, + // Defensively set reusable=true. One has to be careful about this since + // an accidental false value would cause these iterators, that are value + // members of pebbleReadOnly, to be put in the pebbleIterPool. 
+ prefixIter: pebbleIterator{reusable: true}, + normalIter: pebbleIterator{reusable: true}, + prefixEngineIter: pebbleIterator{reusable: true}, + normalEngineIter: pebbleIterator{reusable: true}, } } @@ -1368,14 +1377,14 @@ func (p *pebbleReadOnly) NewMVCCIterator(iterKind MVCCIterKind, opts IterOptions if iter.inuse { panic("iterator already in use") } + // Ensures no timestamp hints etc. + checkOptionsForIterReuse(opts) if iter.iter != nil { - iter.setOptions(opts) + iter.setBounds(opts.LowerBound, opts.UpperBound) } else { iter.init(p.parent.db, p.iter, opts) - // The timestamp hints should be empty given the earlier code, but we are - // being defensive. - if p.iter == nil && opts.MaxTimestampHint.IsEmpty() && opts.MinTimestampHint.IsEmpty() { + if p.iter == nil { // For future cloning. p.iter = iter.iter } @@ -1403,14 +1412,14 @@ func (p *pebbleReadOnly) NewEngineIterator(opts IterOptions) EngineIterator { if iter.inuse { panic("iterator already in use") } + // Ensures no timestamp hints etc. + checkOptionsForIterReuse(opts) if iter.iter != nil { - iter.setOptions(opts) + iter.setBounds(opts.LowerBound, opts.UpperBound) } else { iter.init(p.parent.db, p.iter, opts) - // The timestamp hints should be empty given this is an EngineIterator, - // but we are being defensive. - if p.iter == nil && opts.MaxTimestampHint.IsEmpty() && opts.MinTimestampHint.IsEmpty() { + if p.iter == nil { // For future cloning. p.iter = iter.iter } @@ -1421,6 +1430,18 @@ func (p *pebbleReadOnly) NewEngineIterator(opts IterOptions) EngineIterator { return iter } +// checkOptionsForIterReuse checks that the options are appropriate for +// iterators that are reusable, and panics if not. This includes disallowing +// any timestamp hints. 
+func checkOptionsForIterReuse(opts IterOptions) { + if !opts.MinTimestampHint.IsEmpty() || !opts.MaxTimestampHint.IsEmpty() { + panic("iterator with timestamp hints cannot be reused") + } + if !opts.Prefix && len(opts.UpperBound) == 0 && len(opts.LowerBound) == 0 { + panic("iterator must set prefix or upper bound or lower bound") + } +} + // ConsistentIterators implements the Engine interface. func (p *pebbleReadOnly) ConsistentIterators() bool { return true diff --git a/pkg/storage/pebble_batch.go b/pkg/storage/pebble_batch.go index 3e73b3d6256e..b33bc7e6c1d1 100644 --- a/pkg/storage/pebble_batch.go +++ b/pkg/storage/pebble_batch.go @@ -229,18 +229,18 @@ func (p *pebbleBatch) NewMVCCIterator(iterKind MVCCIterKind, opts IterOptions) M if iter.inuse { panic("iterator already in use") } + // Ensures no timestamp hints etc. + checkOptionsForIterReuse(opts) if iter.iter != nil { - iter.setOptions(opts) + iter.setBounds(opts.LowerBound, opts.UpperBound) } else { if p.batch.Indexed() { iter.init(p.batch, p.iter, opts) } else { iter.init(p.db, p.iter, opts) } - // The timestamp hints should be empty given the earlier code, but we are - // being defensive. - if p.iter == nil && opts.MaxTimestampHint.IsEmpty() && opts.MinTimestampHint.IsEmpty() { + if p.iter == nil { // For future cloning. p.iter = iter.iter } @@ -271,18 +271,18 @@ func (p *pebbleBatch) NewEngineIterator(opts IterOptions) EngineIterator { if iter.inuse { panic("iterator already in use") } + // Ensures no timestamp hints etc. + checkOptionsForIterReuse(opts) if iter.iter != nil { - iter.setOptions(opts) + iter.setBounds(opts.LowerBound, opts.UpperBound) } else { if p.batch.Indexed() { iter.init(p.batch, p.iter, opts) } else { iter.init(p.db, p.iter, opts) } - // The timestamp hints should be empty given this is an EngineIterator, - // but we are being defensive. - if p.iter == nil && opts.MaxTimestampHint.IsEmpty() && opts.MinTimestampHint.IsEmpty() { + if p.iter == nil { // For future cloning. 
p.iter = iter.iter } diff --git a/pkg/storage/pebble_iterator.go b/pkg/storage/pebble_iterator.go index f46917eb4ffd..dc328ac0d6f6 100644 --- a/pkg/storage/pebble_iterator.go +++ b/pkg/storage/pebble_iterator.go @@ -38,9 +38,11 @@ type pebbleIterator struct { // use two slices for each of the bounds since this caller should not change // the slice holding the current bounds, that the callee (pebble.MVCCIterator) // is currently using, until after the caller has made the SetBounds call. - lowerBoundBuf [2][]byte - upperBoundBuf [2][]byte - curBuf int + lowerBoundBuf [2][]byte + upperBoundBuf [2][]byte + curBuf int + testingSetBoundsListener testingSetBoundsListener + // Set to true to govern whether to call SeekPrefixGE or SeekGE. Skips // SSTables based on MVCC/Engine key when true. prefix bool @@ -75,11 +77,16 @@ type cloneableIter interface { Clone() (*pebble.Iterator, error) } +type testingSetBoundsListener interface { + postSetBounds(lower, upper []byte) +} + // Instantiates a new Pebble iterator, or gets one from the pool. func newPebbleIterator( handle pebble.Reader, iterToClone cloneableIter, opts IterOptions, ) *pebbleIterator { iter := pebbleIterPool.Get().(*pebbleIterator) + iter.reusable = false // defensive iter.init(handle, iterToClone, opts) return iter } @@ -169,38 +176,65 @@ func (p *pebbleIterator) init(handle pebble.Reader, iterToClone cloneableIter, o p.inuse = true } -func (p *pebbleIterator) setOptions(opts IterOptions) { - // Overwrite any stale options from last time. - p.options = pebble.IterOptions{} - - if !opts.MinTimestampHint.IsEmpty() || !opts.MaxTimestampHint.IsEmpty() { - panic("iterator with timestamp hints cannot be reused") +// setBounds is called to change the bounds on a pebbleIterator. Note that +// this is not the first time that bounds will be passed to the underlying +// pebble.Iterator. The existing bounds are in p.options. 
+func (p *pebbleIterator) setBounds(lowerBound, upperBound roachpb.Key) { + // If the roachpb.Key bound is nil, the corresponding bound for the + // pebble.Iterator will also be nil. p.options contains the current bounds + // known to the pebble.Iterator. + boundsChanged := ((lowerBound == nil) != (p.options.LowerBound == nil)) || + ((upperBound == nil) != (p.options.UpperBound == nil)) + if !boundsChanged { + // The nil-ness is the same but the values may be different. + if lowerBound != nil { + // Both must be non-nil. We know that we've appended 0x00 to + // p.options.LowerBound, which must be ignored for this comparison. + if !bytes.Equal(p.options.LowerBound[:len(p.options.LowerBound)-1], lowerBound) { + boundsChanged = true + } + } + // If the preceding if-block has not already set boundsChanged=true, see + // if the upper bound has changed. + if !boundsChanged && upperBound != nil { + // Both must be non-nil. We know that we've appended 0x00 to + // p.options.UpperBound, which must be ignored for this comparison. + if !bytes.Equal(p.options.UpperBound[:len(p.options.UpperBound)-1], upperBound) { + boundsChanged = true + } + } } - if !opts.Prefix && len(opts.UpperBound) == 0 && len(opts.LowerBound) == 0 { - panic("iterator must set prefix or upper bound or lower bound") + if !boundsChanged { + // This noop optimization helps the underlying pebble.Iterator to optimize + // seeks. + return } - - p.prefix = opts.Prefix + // Set the bounds to nil, before we selectively change them. + p.options.LowerBound = nil + p.options.UpperBound = nil p.curBuf = (p.curBuf + 1) % 2 i := p.curBuf - if opts.LowerBound != nil { + if lowerBound != nil { // This is the same as - // p.options.LowerBound = EncodeKeyToBuf(p.lowerBoundBuf[i][:0], MVCCKey{Key: opts.LowerBound}) . - // or EngineKey{Key: opts.LowerBound}.EncodeToBuf(...). + // p.options.LowerBound = EncodeKeyToBuf(p.lowerBoundBuf[i][:0], MVCCKey{Key: lowerBound}) . + // or EngineKey{Key: lowerBound}.EncodeToBuf(...). 
// Since we are encoding keys with an empty version anyway, we can just // append the NUL byte instead of calling the above encode functions which // will do the same thing. - p.lowerBoundBuf[i] = append(p.lowerBoundBuf[i][:0], opts.LowerBound...) + p.lowerBoundBuf[i] = append(p.lowerBoundBuf[i][:0], lowerBound...) p.lowerBoundBuf[i] = append(p.lowerBoundBuf[i], 0x00) p.options.LowerBound = p.lowerBoundBuf[i] } - if opts.UpperBound != nil { + if upperBound != nil { // Same as above. - p.upperBoundBuf[i] = append(p.upperBoundBuf[i][:0], opts.UpperBound...) + p.upperBoundBuf[i] = append(p.upperBoundBuf[i][:0], upperBound...) p.upperBoundBuf[i] = append(p.upperBoundBuf[i], 0x00) p.options.UpperBound = p.upperBoundBuf[i] } p.iter.SetBounds(p.options.LowerBound, p.options.UpperBound) + if p.testingSetBoundsListener != nil { + p.testingSetBoundsListener.postSetBounds(p.options.LowerBound, p.options.UpperBound) + } } // Close implements the MVCCIterator interface. @@ -624,11 +658,22 @@ func findSplitKeyUsingIterator( return bestSplitKey, nil } -// SetUpperBound implements the MVCCIterator interface. +// SetUpperBound implements the MVCCIterator interface. Note that this is not +// the first time that bounds will be passed to the underlying +// pebble.Iterator. The existing bounds are in p.options. func (p *pebbleIterator) SetUpperBound(upperBound roachpb.Key) { if upperBound == nil { panic("SetUpperBound must not use a nil key") } + if p.options.UpperBound != nil { + // We know that we've appended 0x00 to p.options.UpperBound, which must be + // ignored for this comparison. + if bytes.Equal(p.options.UpperBound[:len(p.options.UpperBound)-1], upperBound) { + // Nothing to do. This noop optimization helps the underlying + // pebble.Iterator to optimize seeks. 
+ return + } + } p.curBuf = (p.curBuf + 1) % 2 i := p.curBuf if p.options.LowerBound != nil { @@ -639,6 +684,9 @@ func (p *pebbleIterator) SetUpperBound(upperBound roachpb.Key) { p.upperBoundBuf[i] = append(p.upperBoundBuf[i], 0x00) p.options.UpperBound = p.upperBoundBuf[i] p.iter.SetBounds(p.options.LowerBound, p.options.UpperBound) + if p.testingSetBoundsListener != nil { + p.testingSetBoundsListener.postSetBounds(p.options.LowerBound, p.options.UpperBound) + } } // Stats implements the MVCCIterator interface. diff --git a/pkg/storage/pebble_test.go b/pkg/storage/pebble_test.go index 864dd8da893d..8e8d4a2d9e2f 100644 --- a/pkg/storage/pebble_test.go +++ b/pkg/storage/pebble_test.go @@ -186,6 +186,140 @@ func TestPebbleIterReuse(t *testing.T) { iter2.Close() } +type iterBoundsChecker struct { + t *testing.T + expectSetBounds bool + boundsSlices [2][]byte + boundsSlicesCopied [2][]byte +} + +func (ibc *iterBoundsChecker) postSetBounds(lower, upper []byte) { + require.True(ibc.t, ibc.expectSetBounds) + ibc.expectSetBounds = false + // The slices passed in the second from last SetBounds call + // must still be the same. + for i := range ibc.boundsSlices { + if ibc.boundsSlices[i] != nil { + if !bytes.Equal(ibc.boundsSlices[i], ibc.boundsSlicesCopied[i]) { + ibc.t.Fatalf("bound slice changed: expected: %x, actual: %x", + ibc.boundsSlicesCopied[i], ibc.boundsSlices[i]) + } + } + } + // Stash the bounds for later checking. + for i, bound := range [][]byte{lower, upper} { + ibc.boundsSlices[i] = bound + if bound != nil { + ibc.boundsSlicesCopied[i] = append(ibc.boundsSlicesCopied[i][:0], bound...) 
+ } else { + ibc.boundsSlicesCopied[i] = nil + } + } +} + +func TestPebbleIterBoundSliceStabilityAndNoop(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + eng := createTestPebbleEngine().(*Pebble) + defer eng.Close() + iter := newPebbleIterator(eng.db, nil, IterOptions{UpperBound: roachpb.Key("foo")}) + defer iter.Close() + checker := &iterBoundsChecker{t: t} + iter.testingSetBoundsListener = checker + + tc := []struct { + expectSetBounds bool + setUpperOnly bool + lb roachpb.Key + ub roachpb.Key + }{ + { + // [nil, www) + expectSetBounds: true, + ub: roachpb.Key("www"), + }, + { + // [nil, www) + expectSetBounds: false, + ub: roachpb.Key("www"), + }, + { + // [nil, www) + expectSetBounds: false, + setUpperOnly: true, + ub: roachpb.Key("www"), + }, + { + // [ddd, www) + expectSetBounds: true, + lb: roachpb.Key("ddd"), + ub: roachpb.Key("www"), + }, + { + // [ddd, www) + expectSetBounds: false, + setUpperOnly: true, + ub: roachpb.Key("www"), + }, + { + // [ddd, xxx) + expectSetBounds: true, + setUpperOnly: true, + ub: roachpb.Key("xxx"), + }, + { + // [aaa, bbb) + expectSetBounds: true, + lb: roachpb.Key("aaa"), + ub: roachpb.Key("bbb"), + }, + { + // [ccc, ddd) + expectSetBounds: true, + lb: roachpb.Key("ccc"), + ub: roachpb.Key("ddd"), + }, + { + // [ccc, nil) + expectSetBounds: true, + lb: roachpb.Key("ccc"), + }, + { + // [ccc, nil) + expectSetBounds: false, + lb: roachpb.Key("ccc"), + }, + } + var lb, ub roachpb.Key + for _, c := range tc { + t.Run(fmt.Sprintf("%v", c), func(t *testing.T) { + checker.expectSetBounds = c.expectSetBounds + checker.t = t + if c.setUpperOnly { + iter.SetUpperBound(c.ub) + ub = c.ub + } else { + iter.setBounds(c.lb, c.ub) + lb, ub = c.lb, c.ub + } + require.False(t, checker.expectSetBounds) + for i, bound := range [][]byte{lb, ub} { + if (bound == nil) != (checker.boundsSlicesCopied[i] == nil) { + t.Fatalf("inconsistent nil %d", i) + } + if bound != nil { + expected := append([]byte(nil), bound...) 
+ expected = append(expected, 0x00) + if !bytes.Equal(expected, checker.boundsSlicesCopied[i]) { + t.Fatalf("expected: %x, actual: %x", expected, checker.boundsSlicesCopied[i]) + } + } + } + }) + } +} + func makeMVCCKey(a string) MVCCKey { return MVCCKey{Key: []byte(a)} } diff --git a/vendor b/vendor index 797f916a0d9c..ac4d2de45c43 160000 --- a/vendor +++ b/vendor @@ -1 +1 @@ -Subproject commit 797f916a0d9c283c8f05e4d7af9790d4b165b834 +Subproject commit ac4d2de45c4319ec41f3552865fd0e8f60421def