diff --git a/proxy/proxy/client.go b/proxy/proxy/client.go index 84c992b2e..e240a9f22 100644 --- a/proxy/proxy/client.go +++ b/proxy/proxy/client.go @@ -575,7 +575,7 @@ func NewConnSrc(instance string, l net.Listener) <-chan Conn { return ch } -// InstanceVersionContext uses client cache to return instance version string. +// InstanceVersion uses client cache to return instance version string. // // Deprecated: Use Client.InstanceVersionContext instead. func (c *Client) InstanceVersion(instance string) (string, error) { @@ -586,7 +586,7 @@ func (c *Client) InstanceVersion(instance string) (string, error) { func (c *Client) InstanceVersionContext(ctx context.Context, instance string) (string, error) { _, _, version, err := c.cachedCfg(ctx, instance) if err != nil { - return "", nil + return "", err } return version, nil } diff --git a/proxy/proxy/client_test.go b/proxy/proxy/client_test.go index 0d380d60c..906f52feb 100644 --- a/proxy/proxy/client_test.go +++ b/proxy/proxy/client_test.go @@ -569,3 +569,47 @@ func TestConnectingWithInvalidConfig(t *testing.T) { t.Fatalf("wanted ErrUnexpectedFailure, got = %v", err) } } + +var ( + errLocal = errors.New("local failed") + errRemote = errors.New("remote failed") +) + +type failingCertSource struct{} + +func (cs failingCertSource) Local(instance string) (tls.Certificate, error) { + return tls.Certificate{}, errLocal +} + +func (cs failingCertSource) Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error) { + return nil, "", "", "", errRemote +} + +func TestInstanceVersionContext(t *testing.T) { + testCases := []struct { + certSource CertSource + wantErr error + wantVersion string + }{ + { + certSource: newCertSource(&fakeCerts{}, forever), + wantErr: nil, + wantVersion: "fake version", + }, + { + certSource: failingCertSource{}, + wantErr: errLocal, + wantVersion: "", + }, + } + for _, tc := range testCases { + c := newClient(tc.certSource) + v, err := c.InstanceVersionContext(context.Background(), instance) + if v != tc.wantVersion { + t.Fatalf("want version = %v, got version = %v", tc.wantVersion, v) + } + if err != tc.wantErr { + t.Fatalf("want = %v, got = %v", tc.wantErr, err) + } + } +}