diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 1745f9412a24..1efd59d04624 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -138,6 +138,7 @@ ALL_TESTS = [ "//pkg/cmd/release:release_test", "//pkg/cmd/roachprod-microbench:roachprod-microbench_test", "//pkg/cmd/roachtest/clusterstats:clusterstats_test", + "//pkg/cmd/roachtest/option:option_test", "//pkg/cmd/roachtest/roachtestutil/mixedversion:mixedversion_test", "//pkg/cmd/roachtest/roachtestutil:roachtestutil_test", "//pkg/cmd/roachtest/tests:tests_test", @@ -1041,6 +1042,7 @@ GO_TARGETS = [ "//pkg/cmd/roachtest/clusterstats:clusterstats_test", "//pkg/cmd/roachtest/grafana:grafana", "//pkg/cmd/roachtest/option:option", + "//pkg/cmd/roachtest/option:option_test", "//pkg/cmd/roachtest/registry:registry", "//pkg/cmd/roachtest/roachtestutil/clusterupgrade:clusterupgrade", "//pkg/cmd/roachtest/roachtestutil/mixedversion:mixedversion", diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go index 214c77793d78..2916ad0d87c8 100644 --- a/pkg/cmd/roachtest/cluster.go +++ b/pkg/cmd/roachtest/cluster.go @@ -2365,6 +2365,13 @@ func (c *clusterImpl) ConnE( u.User = url.User(connOptions.User) dataSourceName = u.String() } + if len(connOptions.Options) > 0 { + vals := make(url.Values) + for k, v := range connOptions.Options { + vals.Add(k, v) + } + dataSourceName = dataSourceName + "&" + vals.Encode() + } db, err := gosql.Open("postgres", dataSourceName) if err != nil { return nil, err diff --git a/pkg/cmd/roachtest/option/BUILD.bazel b/pkg/cmd/roachtest/option/BUILD.bazel index 97be2b00fa75..6289cff1c520 100644 --- a/pkg/cmd/roachtest/option/BUILD.bazel +++ b/pkg/cmd/roachtest/option/BUILD.bazel @@ -1,5 +1,5 @@ load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "option", @@ -18,4 +18,12 @@ go_library( ], ) +go_test( + name = "option_test", + srcs = ["connection_options_test.go"], + args = ["-test.timeout=295s"], + embed = [":option"], + deps = ["@com_github_stretchr_testify//require"], +) + get_x_data(name = "get_x_data") diff --git a/pkg/cmd/roachtest/option/connection_options.go b/pkg/cmd/roachtest/option/connection_options.go index 1633e8083823..9131df60ce15 100644 --- a/pkg/cmd/roachtest/option/connection_options.go +++ b/pkg/cmd/roachtest/option/connection_options.go @@ -10,9 +10,15 @@ package option +import ( + "fmt" + "time" +) + type ConnOption struct { User string TenantName string + Options map[string]string } func User(user string) func(*ConnOption) { @@ -26,3 +32,20 @@ func TenantName(tenantName string) func(*ConnOption) { option.TenantName = tenantName } } + +func ConnectionOption(key, value string) func(*ConnOption) { + return func(option *ConnOption) { + if len(option.Options) == 0 { + option.Options = make(map[string]string) + } + option.Options[key] = value + } +} + +func ConnectTimeout(t time.Duration) func(*ConnOption) { + sec := int64(t.Seconds()) + if sec < 1 { + sec = 1 + } + return ConnectionOption("connect_timeout", fmt.Sprintf("%d", sec)) +} diff --git a/pkg/cmd/roachtest/option/connection_options_test.go b/pkg/cmd/roachtest/option/connection_options_test.go new file mode 100644 index 000000000000..c853374954e0 --- /dev/null +++ b/pkg/cmd/roachtest/option/connection_options_test.go @@ -0,0 +1,52 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package option + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFirstOptionCreatesMap(t *testing.T) { + var opts ConnOption + o := ConnectionOption("a", "b") + o(&opts) + require.NotNil(t, opts.Options) +} + +func TestTimeoutCalculation(t *testing.T) { + var opts ConnOption + for _, d := range []struct { + t time.Duration + o string + }{ + { + t: time.Second, + o: "1", + }, + { + t: time.Millisecond, + o: "1", + }, + { + t: time.Minute, + o: "60", + }, + } { + t.Run(d.t.String(), func(t *testing.T) { + o := ConnectTimeout(d.t) + o(&opts) + require.Equal(t, d.o, opts.Options["connect_timeout"]) + }) + } +} diff --git a/pkg/cmd/roachtest/tests/loss_of_quorum_recovery.go b/pkg/cmd/roachtest/tests/loss_of_quorum_recovery.go index 30060444378e..f79fd028c24b 100644 --- a/pkg/cmd/roachtest/tests/loss_of_quorum_recovery.go +++ b/pkg/cmd/roachtest/tests/loss_of_quorum_recovery.go @@ -264,7 +264,10 @@ func runRecoverLossOfQuorum(ctx context.Context, t test.Test, c cluster.Cluster, if ctx.Err() != nil { return &recoveryImpossibleError{testOutcome: restartFailed} } - db, err = c.ConnE(ctx, t.L(), 1) + // Note that conn doesn't actually connect, it just creates driver + // and prepares URL. Actual connection is done when statement is + // being executed. + db, err = c.ConnE(ctx, t.L(), 1, option.ConnectTimeout(15*time.Second)) if err == nil { break } @@ -485,7 +488,7 @@ func runHalfOnlineRecoverLossOfQuorum( if ctx.Err() != nil { return &recoveryImpossibleError{testOutcome: restartFailed} } - db, err = c.ConnE(ctx, t.L(), remaining[len(remaining)-1]) + db, err = c.ConnE(ctx, t.L(), remaining[len(remaining)-1], option.ConnectTimeout(15*time.Second)) if err == nil { break }