diff --git a/connection/connection.go b/connection/connection.go index 5990b759e..9a2f9efb2 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -58,6 +58,7 @@ func SetMaxGRPCLogLength(characterCount int) { // The function tries to connect for 30 seconds, and returns an error if no connection has been established at that point. // The function automatically disables TLS and adds interceptor for logging of all gRPC messages at level 5. // If the metricsManager is 'nil', no metrics will be recorded on the gRPC calls. +// The function behaviour can be tweaked with options. // // For a connection to a Unix Domain socket, the behavior after // loosing the connection is configurable. The default is to @@ -72,9 +73,10 @@ func SetMaxGRPCLogLength(characterCount int) { // For other connections, the default behavior from gRPC is used and // loss of connection is not detected reliably. func Connect(address string, metricsManager metrics.CSIMetricsManager, options ...Option) (*grpc.ClientConn, error) { - options = append(options, withTimeout(time.Second*30)) + // Prepend default options + options = append([]Option{WithTimeout(time.Second * 30)}, options...) if metricsManager != nil { - options = append(options, withMetrics(metricsManager)) + options = append([]Option{WithMetrics(metricsManager)}, options...) } return connect(address, options) } @@ -114,15 +116,15 @@ func ExitOnConnectionLoss() func() bool { } } -// withTimeout adds a configurable timeout on the gRPC calls. -func withTimeout(timeout time.Duration) Option { +// WithTimeout adds a configurable timeout on the gRPC calls. +func WithTimeout(timeout time.Duration) Option { return func(o *options) { o.timeout = timeout } } -// withMetrics enables the recording of metrics on the gRPC calls with the provided CSIMetricsManager. -func withMetrics(metricsManager metrics.CSIMetricsManager) Option { +// WithMetrics enables the recording of metrics on the gRPC calls with the provided CSIMetricsManager. +func WithMetrics(metricsManager metrics.CSIMetricsManager) Option { return func(o *options) { o.metricsManager = metricsManager } diff --git a/connection/connection_test.go b/connection/connection_test.go index 7340a7108..ed5f143bb 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -223,7 +223,7 @@ func TestTimeout(t *testing.T) { startTime := time.Now() timeout := 5 * time.Second - conn, err := connect(path.Join(tmp, "no-such.sock"), []Option{withTimeout(timeout)}) + conn, err := connect(path.Join(tmp, "no-such.sock"), []Option{WithTimeout(timeout)}) endTime := time.Now() if assert.Error(t, err, "connection should fail") { assert.InEpsilon(t, timeout, endTime.Sub(startTime), 1, "connection timeout")