diff --git a/credentials/tls/certprovider/pemfile/watcher_test.go b/credentials/tls/certprovider/pemfile/watcher_test.go index d5ce5ab7e94d..e43cf7358eca 100644 --- a/credentials/tls/certprovider/pemfile/watcher_test.go +++ b/credentials/tls/certprovider/pemfile/watcher_test.go @@ -20,7 +20,7 @@ package pemfile import ( "context" - "crypto/x509" + "fmt" "io/ioutil" "math/big" "os" @@ -29,6 +29,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" @@ -55,6 +56,30 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } +func compareKeyMaterial(got, want *certprovider.KeyMaterial) error { + // x509.Certificate type defines an Equal() method, but does not check for + // nil. This has been fixed in + // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, + // but this is only available starting go1.14. + // TODO(easwars): Remove this check once we remove support for go1.13. + if (got.Certs == nil && want.Certs != nil) || (want.Certs == nil && got.Certs != nil) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } + if !cmp.Equal(got.Certs, want.Certs, cmp.AllowUnexported(big.Int{})) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } + // x509.CertPool contains only unexported fields some of which contain other + // unexported fields. So usage of cmp.AllowUnexported() or + // cmpopts.IgnoreUnexported() does not help us much here. Also, the standard + // library does not provide a way to compare CertPool values. Comparing the + // subjects field of the certs in the CertPool seems like a reasonable + // approach. + if gotR, wantR := got.Roots.Subjects(), want.Roots.Subjects(); !cmp.Equal(gotR, wantR, cmpopts.EquateEmpty()) { + return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR) + } + return nil +} + // TestNewProvider tests the NewProvider() function with different inputs. func (s) TestNewProvider(t *testing.T) { tests := []struct { @@ -263,7 +288,7 @@ func (s) TestProvider_UpdateSuccess(t *testing.T) { if err != nil { t.Fatalf("provider.KeyMaterial() failed: %v", err) } - if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + if err := compareKeyMaterial(km1, km2); err == nil { t.Fatal("expected provider to return new key material after update to underlying file") } @@ -279,7 +304,7 @@ func (s) TestProvider_UpdateSuccess(t *testing.T) { if err != nil { t.Fatalf("provider.KeyMaterial() failed: %v", err) } - if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + if err := compareKeyMaterial(km2, km3); err == nil { t.Fatal("expected provider to return new key material after update to underlying file") } } @@ -363,7 +388,7 @@ func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) { t.Fatalf("provider.KeyMaterial() failed: %v", err) } - if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + if err := compareKeyMaterial(km1, km2); err == nil { t.Fatal("expected provider to return new key material after symlink update") } } @@ -403,8 +428,8 @@ func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) { if err != nil { t.Fatalf("provider.KeyMaterial() failed: %v", err) } - if !cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { - t.Fatal("expected provider to not update key material") + if err := compareKeyMaterial(km1, km2); err != nil { + t.Fatalf("expected provider to not update key material: %v", err) } // Update the key file to match the cert file. @@ -418,7 +443,7 @@ func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) { if err != nil { t.Fatalf("provider.KeyMaterial() failed: %v", err) } - if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + if err := compareKeyMaterial(km2, km3); err == nil { t.Fatal("expected provider to return new key material after update to underlying file") } }