diff --git a/api/types_test.go b/api/types_test.go new file mode 100644 index 0000000..2860edc --- /dev/null +++ b/api/types_test.go @@ -0,0 +1,143 @@ +package api + +import ( + "os" + "testing" +) + +func TestOIDCConfig_Validate(t *testing.T) { + OIDCNotEnabled := false + OIDCIsEnabled := true + emptyFilename := "" + filenameThatDoesntExist := "doesnt-exist.yaml" + fullConfigFile, _ := os.CreateTemp("", "") + defer func(name string) { + err := os.Remove(name) + if err != nil { + t.Fatal(err) + } + }(fullConfigFile.Name()) + filenameComplete := fullConfigFile.Name() + yamlDataComplete := ` +issuer_url: https://example.com +client_id: client-id +client_secret: client-secret +` + _, _ = fullConfigFile.Write([]byte(yamlDataComplete)) + err := fullConfigFile.Close() + if err != nil { + return + } + emptyConfigFile, _ := os.CreateTemp("", "") + defer func(name string) { + err := os.Remove(name) + if err != nil { + t.Fatal(err) + } + }(emptyConfigFile.Name()) + filenameNotComplete := emptyConfigFile.Name() + yamlDataNotComplete := `` + _, _ = emptyConfigFile.Write([]byte(yamlDataNotComplete)) + err = emptyConfigFile.Close() + if err != nil { + return + } + type fields struct { + Enabled *bool + Filename *string + Attributes OIDCAttributes + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "should return nil if oidc is not enabled", + fields: fields{ + Enabled: &OIDCNotEnabled, + }, + wantErr: false, + }, + { + name: "should return error if oidc is enabled but no filename is provided", + fields: fields{ + Enabled: &OIDCIsEnabled, + Filename: &emptyFilename, + }, + wantErr: true, + }, + { + name: "should return err if oidc is enabled and filename is provided, but file does not exist", + fields: fields{ + Enabled: &OIDCIsEnabled, + Filename: &filenameThatDoesntExist, + }, + wantErr: true, + }, + { + name: "should return an error if oidc is enabled but the client id is not provided", + fields: fields{ + Enabled: &OIDCIsEnabled, + Filename: &filenameNotComplete, + Attributes: OIDCAttributes{ + IssuerURL: "https://test.com", + ClientID: "", + }, + }, + wantErr: true, + }, + { + name: "should return an error if oidc is enabled but the client secret is not provided", + fields: fields{ + Enabled: &OIDCIsEnabled, + Filename: &filenameNotComplete, + Attributes: OIDCAttributes{ + IssuerURL: "https://test.com", + ClientID: "test", + ClientSecret: "", + }, + }, + wantErr: true, + }, + { + name: "should return nil if oidc is enabled and all required fields are provided", + fields: fields{ + Enabled: &OIDCIsEnabled, + Filename: &filenameComplete, + Attributes: OIDCAttributes{ + IssuerURL: "https://test.com", + ClientID: "test", + ClientSecret: "test", + }, + }, + wantErr: false, + }, + { + name: "should return nil if oidc is enabled and all required fields are provided along with optional fields", + fields: fields{ + Enabled: &OIDCIsEnabled, + Filename: &filenameComplete, + Attributes: OIDCAttributes{ + IssuerURL: "https://test.com", + ClientID: "test", + ClientSecret: "test", + Audience: "test", + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &OIDCConfig{ + Enabled: tt.fields.Enabled, + Filename: tt.fields.Filename, + Attributes: tt.fields.Attributes, + } + if err := c.Validate(); (err != nil) != tt.wantErr { + t.Errorf("OIDCConfig.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/authtoken/authtoken_test.go b/pkg/authtoken/authtoken_test.go new file mode 100644 index 0000000..17056f7 --- /dev/null +++ b/pkg/authtoken/authtoken_test.go @@ -0,0 +1,349 @@ +package authtoken + +import ( + "crypto/tls" + "crypto/x509" + "log" + "net/http" + "net/url" + "reflect" + "testing" +) + +var ( + a byte = 116 + b byte = 101 + c byte = 115 + d byte = 116 + options = TokenVerifierOptions{ + Enabled: true, + URL: url.URL{ + Scheme: "https", + Host: "localhost:8080", + }, + CACertEnabled: true, + CACertRaw: CAPEMCertificateRawBytes{ + a, b, c, d, + }, + } + optionsWithoutCACert = TokenVerifierOptions{ + Enabled: true, + CACertEnabled: false, + CACertRaw: nil, + } + clientWithCACert, _ = defaultHTTPClient(options) +) + +func TestCAPEMCertificateRawBytes_ToByteSlice(t *testing.T) { + tests := []struct { + name string + c *CAPEMCertificateRawBytes + want []byte + }{ + { + name: "should return a byte slice", + c: &CAPEMCertificateRawBytes{ + a, b, c, d, + }, + want: []byte{116, 101, 115, 116}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.c.ToByteSlice(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CAPEMCertificateRawBytes.ToByteSlice() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToCAPEMCertificateRawBytes(t *testing.T) { + type args struct { + certRawBytes []byte + } + tests := []struct { + name string + args args + want CAPEMCertificateRawBytes + }{ + { + name: "should return a CAPEMCertificateRawBytes", + args: args{ + certRawBytes: []byte{116, 101, 115, 116}, + }, + want: CAPEMCertificateRawBytes{ + a, b, c, d, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ToCAPEMCertificateRawBytes(tt.args.certRawBytes); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ToCAPEMCertificateRawBytes() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewDefaultTokenVerifier(t *testing.T) { + type args struct { + options TokenVerifierOptions + } + tests := []struct { + name string + args args + want TokenVerifier + wantErr bool + }{ + { + name: "should return an error", + args: args{ + options: options, + }, + want: TokenVerifier{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewDefaultTokenVerifier(tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("NewDefaultTokenVerifier() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewDefaultTokenVerifier() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_newTokenVerifier(t *testing.T) { + type args struct { + client http.Client + options TokenVerifierOptions + } + tests := []struct { + name string + args args + want TokenVerifier + }{ + { + name: "should return a TokenVerifier", + args: args{ + client: clientWithCACert, + options: options, + }, + want: TokenVerifier{ + client: clientWithCACert, + options: options, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newTokenVerifier(tt.args.client, tt.args.options); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newTokenVerifier() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTokenVerifier_Enabled(t *testing.T) { + type fields struct { + client http.Client + options TokenVerifierOptions + } + tests := []struct { + name string + fields fields + want bool + }{ + { + name: "should return true", + fields: fields{ + client: clientWithCACert, + options: options, + }, + want: true, + }, + { + name: "should return false", + fields: fields{ + client: clientWithCACert, + options: TokenVerifierOptions{ + Enabled: false, + URL: url.URL{ + Scheme: "https", + Host: "localhost:8080", + }, + CACertEnabled: true, + CACertRaw: CAPEMCertificateRawBytes{ + a, b, c, d, + }, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &TokenVerifier{ + client: tt.fields.client, + options: tt.fields.options, + } + if got := tr.Enabled(); got != tt.want { + t.Errorf("TokenVerifier.Enabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTokenVerifier_GetAuthenticationToken(t *testing.T) { + type fields struct { + client http.Client + options TokenVerifierOptions + } + type args struct { + r *http.Request + } + tests := []struct { + name string + fields fields + args args + want string + }{ + { + name: "should return token", + fields: fields{ + client: clientWithCACert, + options: options, + }, + args: args{ + r: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer token"}, + }, + }, + }, + want: "Bearer token", + }, + { + name: "should return empty string", + fields: fields{ + client: clientWithCACert, + options: options, + }, + args: args{ + r: &http.Request{ + Header: http.Header{}, + }, + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &TokenVerifier{ + client: tt.fields.client, + options: tt.fields.options, + } + if got := tr.GetAuthenticationToken(tt.args.r); got != tt.want { + t.Errorf("TokenVerifier.GetAuthenticationToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTokenVerifier_ValidateToken(t *testing.T) { + type fields struct { + client http.Client + options TokenVerifierOptions + } + type args struct { + url *url.URL + clusterID string + token string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "should return error", + fields: fields{ + client: clientWithCACert, + options: options, + }, + args: args{ + url: &url.URL{ + Scheme: "https", + Host: "localhost:8080", + }, + clusterID: "clusterID", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &TokenVerifier{ + client: tt.fields.client, + options: tt.fields.options, + } + if err := tr.ValidateToken(tt.args.url, tt.args.clusterID, tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("TokenVerifier.ValidateToken() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_defaultTLSConfig(t *testing.T) { + rootCAs, err := x509.SystemCertPool() + var tlsConfig tls.Config + if err != nil { + log.Print("failed to load system cert pool") + } + type args struct { + options TokenVerifierOptions + rootCA *x509.CertPool + } + tests := []struct { + name string + args args + want tls.Config + wantErr bool + }{ + { + name: "should return tls config if CACertEnabled is false", + args: args{ + options: optionsWithoutCACert, + rootCA: rootCAs, + }, + want: tlsConfig, + wantErr: false, + }, + { + name: "should return error", + args: args{ + options: options, + }, + want: tls.Config{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := defaultTLSConfig(tt.args.options) + if (err != nil) != tt.wantErr { + t.Errorf("defaultTLSConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("defaultTLSConfig() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/remotewrite/remotewrite_test.go b/pkg/remotewrite/remotewrite_test.go new file mode 100644 index 0000000..f27f6b3 --- /dev/null +++ b/pkg/remotewrite/remotewrite_test.go @@ -0,0 +1,85 @@ +package remotewrite + +import ( + "io" + "net/http" + "testing" + + "github.com/golang/snappy" + "github.com/onsi/gomega" + "go.buf.build/protocolbuffers/go/prometheus/prometheus" + "google.golang.org/protobuf/proto" +) + +func TestDecodeWriteRequest(t *testing.T) { + rw := &prometheus.WriteRequest{ + Timeseries: []*prometheus.TimeSeries{ + { + Labels: []*prometheus.Label{ + { + Name: "test", + Value: "test", + }, + }, + Samples: []*prometheus.Sample{ + { + Value: 1, + Timestamp: 1, + }, + }, + }, + }, + } + r := &http.Request{} + g := gomega.NewGomegaWithT(t) + err := PopulateRequestBody(rw, r) + g.Expect(err).To(gomega.BeNil()) + + req, err := DecodeWriteRequest(r) + g.Expect(err).To(gomega.BeNil()) + + g.Expect(proto.Equal(rw, req)).To(gomega.BeTrue()) + + // Modify the HTTP request to have an invalid body + r.Body = io.NopCloser(io.MultiReader(r.Body, r.Body)) + writeReq, err := DecodeWriteRequest(r) + g.Expect(err).ToNot(gomega.BeNil()) + g.Expect(writeReq).To(gomega.BeNil()) +} + +func TestPopulateRequestBody(t *testing.T) { + rw := &prometheus.WriteRequest{ + Timeseries: []*prometheus.TimeSeries{ + { + Labels: []*prometheus.Label{ + { + Name: "test", + Value: "test", + }, + }, + Samples: []*prometheus.Sample{ + { + Value: 1, + Timestamp: 1, + }, + }, + }, + }, + } + r := &http.Request{} + g := gomega.NewGomegaWithT(t) + err := PopulateRequestBody(rw, r) + g.Expect(err).To(gomega.BeNil()) + + compressed, err := io.ReadAll(r.Body) + g.Expect(err).To(gomega.BeNil()) + + reqBuf, err := snappy.Decode(nil, compressed) + g.Expect(err).To(gomega.BeNil()) + + var req prometheus.WriteRequest + err = proto.Unmarshal(reqBuf, &req) + g.Expect(err).To(gomega.BeNil()) + + g.Expect(proto.Equal(rw, &req)).To(gomega.BeTrue()) +} diff --git a/pkg/remotewrite/validate_test.go b/pkg/remotewrite/validate_test.go index e144353..9bdd318 100644 --- a/pkg/remotewrite/validate_test.go +++ b/pkg/remotewrite/validate_test.go @@ -1,9 +1,10 @@ package remotewrite import ( + "testing" + "github.com/onsi/gomega" "go.buf.build/protocolbuffers/go/prometheus/prometheus" - "testing" ) var ( @@ -117,3 +118,114 @@ func TestFindClusterIDs(t *testing.T) { }) } } + +func TestValidateRequest(t *testing.T) { + type args struct { + remoteWriteRequest *prometheus.WriteRequest + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "should return an error if there are multiple cluster IDs in the request", + args: args{ + remoteWriteRequest: reqWithMultipleClusterID, + }, + wantErr: true, + }, + { + name: "should return an error if there are no cluster IDs in the request", + args: args{ + remoteWriteRequest: req, + }, + wantErr: true, + }, + { + name: "should succeed if there is one cluster ID in the request", + args: args{ + remoteWriteRequest: reqWithOneClusterID, + }, + want: "cluster-1", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ValidateRequest(tt.args.remoteWriteRequest) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ValidateRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsMetadataRequest(t *testing.T) { + metadataReq := &prometheus.WriteRequest{ + Metadata: []*prometheus.MetricMetadata{ + { + MetricFamilyName: "test", + }, + }, + } + metadataAndTimeseriesReq := &prometheus.WriteRequest{ + Metadata: []*prometheus.MetricMetadata{ + { + MetricFamilyName: "test", + }, + }, + Timeseries: []*prometheus.TimeSeries{ + { + Labels: []*prometheus.Label{ + { + Name: "cluster_id", + Value: "cluster-1", + }, + }, + }, + }, + } + type args struct { + remoteWriteRequest *prometheus.WriteRequest + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "should return false if the request is not a metadata request", + args: args{ + remoteWriteRequest: reqWithOneClusterID, + }, + want: false, + }, + { + name: "should return true if the request is a metadata request", + args: args{ + remoteWriteRequest: metadataReq, + }, + want: true, + }, + { + name: "should return false if the request is a metadata request and a timeseries request", + args: args{ + remoteWriteRequest: metadataAndTimeseriesReq, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsMetadataRequest(tt.args.remoteWriteRequest); got != tt.want { + t.Errorf("IsMetadataRequest() = %v, want %v", got, tt.want) + } + }) + } +}