From 5b42b38dc6b0799e92dd0d8ce663e9e539572c6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20=C3=98stergaard=20Jensen?= Date: Thu, 26 Sep 2024 08:38:53 +0200 Subject: [PATCH] Add unittests --- README.md | 2 - cmd/webhook/logging_test.go | 77 +++ cmd/webhook/main_test.go | 116 +++++ cmd/webhook/provider_test.go | 725 ++++++++++++++++++++++++++++ cmd/webhook/server_test.go | 37 ++ cmd/webhook/tidydns/metrics_test.go | 58 +++ cmd/webhook/tidydns/tidydns_test.go | 222 +++++++++ cmd/webhook/zoneprovider_test.go | 150 ++++++ 8 files changed, 1385 insertions(+), 2 deletions(-) create mode 100644 cmd/webhook/logging_test.go create mode 100644 cmd/webhook/main_test.go create mode 100644 cmd/webhook/provider_test.go create mode 100644 cmd/webhook/server_test.go create mode 100644 cmd/webhook/tidydns/metrics_test.go create mode 100644 cmd/webhook/tidydns/tidydns_test.go create mode 100644 cmd/webhook/zoneprovider_test.go diff --git a/README.md b/README.md index 90810f5..c37a441 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,5 @@ go build cmd/webhook/ [tidydns-go](https://github.com/neticdk/tidydns-go) instead of the local tidydns package - So far the record types are A, AAAA and CNAME -- Needs some unit tests - More GitHub actions - - Unit tests - Relase pipeline diff --git a/cmd/webhook/logging_test.go b/cmd/webhook/logging_test.go new file mode 100644 index 0000000..1564c49 --- /dev/null +++ b/cmd/webhook/logging_test.go @@ -0,0 +1,77 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestLoggingSetup(t *testing.T) { + tests := []struct { + name string + logFormat string + logLevel string + addSource bool + expectErr bool + expectText string + }{ + { + name: "JSON format with info level", + logFormat: "json", + logLevel: "info", + addSource: false, + expectErr: false, + expectText: `"level":"INFO"`, + }, + { + name: "Text format with debug level", + logFormat: "text", + logLevel: "debug", + addSource: false, + expectErr: false, + expectText: "level=DEBUG", + }, + { + name: "Invalid log level", + logFormat: "json", + logLevel: "invalid", + addSource: false, + expectErr: true, + expectText: `"level":"INFO"`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var buf bytes.Buffer + out := &buf + + logger := loggingSetup(test.logFormat, test.logLevel, out, test.addSource) + + if test.expectErr { + if buf.Len() == 0 { + t.Errorf("Expected error log, got none") + } + } + + logger.Info("test log") + logOutput := buf.String() + + if !bytes.Contains([]byte(logOutput), []byte(test.expectText)) { + t.Errorf("Expected log output to contain %q, got %q", test.expectText, logOutput) + } + }) + } +} + +func TestLoggingSetupWithSource(t *testing.T) { + var buf bytes.Buffer + out := &buf + + logger := loggingSetup("text", "info", out, true) + logger.Info("test log with source") + + logOutput := buf.String() + if !bytes.Contains([]byte(logOutput), []byte("logging_test.go")) { + t.Errorf("Expected log output to contain source file info, got %q", logOutput) + } +} diff --git a/cmd/webhook/main_test.go b/cmd/webhook/main_test.go new file mode 100644 index 0000000..c6d0a5e --- /dev/null +++ b/cmd/webhook/main_test.go @@ -0,0 +1,116 @@ +package main + +import ( + "flag" + "os" + "testing" + "time" +) + +func TestParseConfig(t *testing.T) { + // Save the original command-line arguments and defer restoring them + origArgs := os.Args + defer func() { os.Args = origArgs }() + + // Save the original environment variables and defer restoring them + origTidyUser := os.Getenv("TIDYDNS_USER") + origTidyPass := os.Getenv("TIDYDNS_PASS") + defer func() { + os.Setenv("TIDYDNS_USER", origTidyUser) + os.Setenv("TIDYDNS_PASS", origTidyPass) + }() + + // Set up test cases + tests := []struct { + name string + args []string + envUser string + envPass string + expectedConfig *config + expectError bool + }{ + { + name: "default values", + args: []string{"cmd"}, + envUser: "testuser", + envPass: "testpass", + expectedConfig: &config{ + logLevel: "info", + logFormat: "text", + tidyEndpoint: "", + readTimeout: 5 * time.Second, + writeTimeout: 10 * time.Second, + zoneUpdateInterval: 10 * time.Minute, + tidyUsername: "testuser", + tidyPassword: "testpass", + }, + expectError: false, + }, + { + name: "custom values", + args: []string{"cmd", "--log-level=debug", "--log-format=json", "--tidydns-endpoint=http://example.com", "--read-timeout=3s", "--write-timeout=6s", "--zone-update-interval=15m"}, + envUser: "customuser", + envPass: "custompass", + expectedConfig: &config{ + logLevel: "debug", + logFormat: "json", + tidyEndpoint: "http://example.com", + readTimeout: 3 * time.Second, + writeTimeout: 6 * time.Second, + zoneUpdateInterval: 15 * time.Minute, + tidyUsername: "customuser", + tidyPassword: "custompass", + }, + expectError: false, + }, + { + name: "invalid duration", + args: []string{"cmd", "--zone-update-interval=invalid"}, + envUser: "testuser", + envPass: "testpass", + expectedConfig: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set command-line arguments + os.Args = tt.args + + // Set environment variables + os.Setenv("TIDYDNS_USER", tt.envUser) + os.Setenv("TIDYDNS_PASS", tt.envPass) + + // Reset the flag package to avoid conflicts + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + // Call parseConfig + cfg, err := parseConfig() + + // Check for errors + if tt.expectError { + if err == nil { + t.Fatalf("expected an error but got none") + } + return + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + + // Compare the result with the expected config + if cfg.logLevel != tt.expectedConfig.logLevel || + cfg.logFormat != tt.expectedConfig.logFormat || + cfg.tidyEndpoint != tt.expectedConfig.tidyEndpoint || + cfg.readTimeout != tt.expectedConfig.readTimeout || + cfg.writeTimeout != tt.expectedConfig.writeTimeout || + cfg.zoneUpdateInterval != tt.expectedConfig.zoneUpdateInterval || + cfg.tidyUsername != tt.expectedConfig.tidyUsername || + cfg.tidyPassword != tt.expectedConfig.tidyPassword { + t.Errorf("expected config %+v, but got %+v", tt.expectedConfig, cfg) + } + }) + } +} diff --git a/cmd/webhook/provider_test.go b/cmd/webhook/provider_test.go new file mode 100644 index 0000000..b3953f5 --- /dev/null +++ b/cmd/webhook/provider_test.go @@ -0,0 +1,725 @@ +/* +Copyright 2024 Netic A/S. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/neticdk/external-dns-tidydns-webhook/cmd/webhook/tidydns" + "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/plan" +) + +type mockTidyDNSClient struct { + zones []tidydns.Zone + createdRecords []tidydns.Record + deletedRecordIds []json.Number + err error +} + +func (m *mockTidyDNSClient) CreateRecord(zoneID json.Number, record *tidydns.Record) error { + if m.err != nil { + return m.err + } + + m.createdRecords = append(m.createdRecords, *record) + return nil +} + +func (m *mockTidyDNSClient) ListRecords(zoneID json.Number) ([]tidydns.Record, error) { + if m.err != nil { + return nil, m.err + } + + return m.createdRecords, nil +} + +func (m *mockTidyDNSClient) DeleteRecord(zoneID json.Number, recordID json.Number) error { + if m.err != nil { + return m.err + } + + m.deletedRecordIds = append(m.deletedRecordIds, recordID) + return nil +} + +func (m *mockTidyDNSClient) ListZones() ([]tidydns.Zone, error) { + return m.zones, m.err +} + +type mockZoneProvider struct{} + +func (m *mockZoneProvider) getZones() []tidydns.Zone { + return []tidydns.Zone{ + {Name: "example.com"}, + } +} + +func TestNewProvider(t *testing.T) { + tidy := &mockTidyDNSClient{} + zoneUpdateInterval := 10 * time.Minute + provider := newProvider(tidy, zoneUpdateInterval) + + if provider.tidy != tidy { + t.Errorf("expected tidy to be %v, got %v", tidy, provider.tidy) + } + + if provider.zoneProvider == nil { + t.Error("expected zoneProvider to be initialized") + } +} + +func TestGetDomainFilter(t *testing.T) { + tidy := &mockTidyDNSClient{} + zoneProvider := &mockZoneProvider{} + provider := &tidyProvider{ + tidy: tidy, + zoneProvider: zoneProvider, + } + + domainFilter := provider.GetDomainFilter() + expectedDomains := []string{"example.com"} + + for _, domain := range expectedDomains { + if !domainFilter.Match(domain) { + t.Errorf("expected domain filter to match %s", domain) + } + } +} + +func TestRecords(t *testing.T) { + tidy := &mockTidyDNSClient{} + zoneProvider := &mockZoneProvider{} + provider := &tidyProvider{ + tidy: tidy, + zoneProvider: zoneProvider, + } + + tests := []struct { + name string + mockRecords []tidydns.Record + expectedError bool + expectedResult []*Endpoint + }{ + { + name: "Valid A record", + mockRecords: []tidydns.Record{ + { + ID: "1", + Type: "A", + Name: "test", + Destination: "1.2.3.4", + TTL: json.Number("300"), + ZoneName: "example.com", + ZoneID: "1", + }, + }, + expectedError: false, + expectedResult: []*Endpoint{ + endpoint.NewEndpointWithTTL("test.example.com", "A", 300, "1.2.3.4"), + }, + }, + { + name: "Fail to list records", + mockRecords: []tidydns.Record{ + { + ID: "2", + Type: "A", + Name: "test", + Destination: "1.2.3.4", + TTL: json.Number("300"), + ZoneName: "example.com", + ZoneID: "1", + }, + }, + expectedError: true, + expectedResult: nil, + }, + { + name: "Invalid TTL", + mockRecords: []tidydns.Record{ + { + ID: "2", + Type: "A", + Name: "invalid-ttl", + Destination: "1.2.3.4", + TTL: json.Number("300.2"), + ZoneName: "example.com", + ZoneID: "1", + }, + }, + expectedError: false, + expectedResult: []*Endpoint{}, + }, + { + name: "Multiple records", + mockRecords: []tidydns.Record{ + { + ID: "3", + Type: "A", + Name: "multi", + Destination: "1.2.3.4", + TTL: json.Number("300"), + ZoneName: "example.com", + ZoneID: "1", + }, + { + ID: "4", + Type: "A", + Name: "multi", + Destination: "5.6.7.8", + TTL: json.Number("300"), + ZoneName: "example.com", + ZoneID: "1", + }, + }, + expectedError: false, + expectedResult: []*Endpoint{ + endpoint.NewEndpointWithTTL("multi.example.com", "A", 300, "1.2.3.4", "5.6.7.8"), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tidy.createdRecords = test.mockRecords + if test.expectedError { + tidy.err = fmt.Errorf("list records error") + } else { + tidy.err = nil + } + + records, err := provider.Records(context.Background()) + + if test.expectedError { + if err == nil { + t.Fatalf("expected error, got none") + } + return + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(records) != len(test.expectedResult) { + fmt.Println(records) + t.Fatalf("expected %d records, got %d", len(test.expectedResult), len(records)) + } + + for i, record := range records { + if record.DNSName != test.expectedResult[i].DNSName || record.RecordType != test.expectedResult[i].RecordType || record.RecordTTL != test.expectedResult[i].RecordTTL || len(record.Targets) != len(test.expectedResult[i].Targets) || record.Targets[0] != test.expectedResult[i].Targets[0] { + t.Errorf("expected %v, got %v", test.expectedResult[i], record) + } + } + }) + } +} + +func TestAdjustEndpoints(t *testing.T) { + // Labels are not added by the constructor, so we add them manually after + // the fact and use them as test parameters below. + ARecWithLabels := endpoint.NewEndpointWithTTL("example.com", "A", 100, "1.2.3.4") + ARecWithLabels.Labels = map[string]string{"label": "value", "label2": "value2"} + + TXTRecWithLabels := endpoint.NewEndpointWithTTL("example.com", "TXT", 300, "\"v=spf1 include:example.com ~all\"") + TXTRecWithLabels.Labels = map[string]string{"label": "value", "label2": "value2"} + + tests := []struct { + name string + endpoints []*Endpoint + expected []*Endpoint + }{ + { + name: "Adjust TTL and remove labels", + endpoints: []*Endpoint{ARecWithLabels, TXTRecWithLabels}, + expected: []*Endpoint{ + endpoint.NewEndpointWithTTL("example.com", "A", 300, "1.2.3.4"), + endpoint.NewEndpointWithTTL("example.com", "TXT", 300, "\"v=spf1 include:example.com ~all\""), + }, + }, + { + name: "Adjust TTL to minimum and encode punycode", + endpoints: []*Endpoint{ + endpoint.NewEndpointWithTTL("xn--exmple-cua.com", "A", 100, "1.2.3.4"), + }, + expected: []*Endpoint{ + endpoint.NewEndpointWithTTL("xn--exmple-cua.com", "A", 300, "1.2.3.4"), + }, + }, + { + name: "No adjustment needed", + endpoints: []*Endpoint{ + endpoint.NewEndpointWithTTL("example.com", "A", 300, "1.2.3.4"), + }, + expected: []*Endpoint{ + endpoint.NewEndpointWithTTL("example.com", "A", 300, "1.2.3.4"), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tidy := &mockTidyDNSClient{} + provider := &tidyProvider{ + tidy: tidy, + zoneProvider: &mockZoneProvider{}, + } + + result, err := provider.AdjustEndpoints(test.endpoints) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(result) != len(test.expected) { + t.Fatalf("expected %d endpoints, got %d", len(test.expected), len(result)) + } + + for i, endpoint := range result { + if endpoint.DNSName != test.expected[i].DNSName || endpoint.RecordType != test.expected[i].RecordType || endpoint.RecordTTL != test.expected[i].RecordTTL || len(endpoint.Targets) != len(test.expected[i].Targets) || endpoint.Targets[0] != test.expected[i].Targets[0] { + t.Errorf("expected %v, got %v", test.expected[i], endpoint) + } + } + }) + } +} + +func TestApplyChanges(t *testing.T) { + tidy := &mockTidyDNSClient{} + zoneProvider := &mockZoneProvider{} + provider := &tidyProvider{ + tidy: tidy, + zoneProvider: zoneProvider, + } + + tests := []struct { + name string + changes *plan.Changes + expectErr bool + }{ + { + name: "Create record", + expectErr: false, + changes: &plan.Changes{ + Create: []*Endpoint{ + endpoint.NewEndpointWithTTL("create.example.com", "A", 300, "1.2.3.4"), + }, + }, + }, + { + name: "Delete record", + expectErr: false, + changes: &plan.Changes{ + Delete: []*Endpoint{ + endpoint.NewEndpointWithTTL("delete.example.com", "A", 300, "1.2.3.4"), + }, + }, + }, + { + name: "Update record", + expectErr: false, + changes: &plan.Changes{ + UpdateOld: []*Endpoint{ + endpoint.NewEndpointWithTTL("update.example.com", "A", 300, "1.2.3.4"), + }, + UpdateNew: []*Endpoint{ + endpoint.NewEndpointWithTTL("update.example.com", "A", 300, "5.6.7.8"), + }, + }, + }, + { + name: "Fail updating record", + expectErr: true, + changes: &plan.Changes{ + UpdateOld: []*Endpoint{ + endpoint.NewEndpointWithTTL("update.example.com", "A", 300, "1.2.3.4"), + }, + UpdateNew: []*Endpoint{ + endpoint.NewEndpointWithTTL("update.example.com", "A", 300, "5.6.7.8"), + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.expectErr { + tidy.err = fmt.Errorf("apply changes error") + } else { + tidy.err = nil + } + + err := provider.ApplyChanges(context.Background(), test.changes) + if !test.expectErr && err != nil { + t.Fatalf("expected no error, got %v", err) + } + }) + } +} + +func TestDeleteEndpoint(t *testing.T) { + allRecords := []tidydns.Record{ + { + ID: "1", + Type: "A", + Name: "delete", + Destination: "1.2.3.4", + TTL: json.Number("300"), + ZoneName: "example.com", + ZoneID: "1", + }, + { + ID: "2", + Type: "CNAME", + Name: "www", + Destination: "example.com", + TTL: json.Number("300"), + ZoneName: "example.com", + ZoneID: "1", + }, + } + + tests := []struct { + name string + encounterErr error + endpoint *Endpoint + expected []json.Number + }{ + { + name: "Delete A record", + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("delete.example.com", "A", 300, "1.2.3.4"), + expected: []json.Number{ + json.Number("1"), + }, + }, + { + name: "Delete CNAME record", + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("www.example.com", "CNAME", 300, "example.com"), + expected: []json.Number{ + json.Number("2"), + }, + }, + { + name: "Delete non-existing record", + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("nonexistent.example.com", "A", 300, "1.2.3.4"), + expected: []json.Number{}, + }, + { + name: "Error on delete", + encounterErr: fmt.Errorf("delete record error"), + endpoint: endpoint.NewEndpointWithTTL("delete.example.com", "A", 300, "1.2.3.4"), + expected: []json.Number{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tidy := &mockTidyDNSClient{ + err: test.encounterErr, + } + + provider := &tidyProvider{ + tidy: tidy, + zoneProvider: &mockZoneProvider{}, + } + + provider.deleteEndpoint(allRecords, test.endpoint) + + if len(tidy.deletedRecordIds) != len(test.expected) { + t.Fatalf("expected %d records to be deleted, got %d", len(test.expected), len(tidy.deletedRecordIds)) + } + + for i, recordId := range tidy.deletedRecordIds { + if recordId != test.expected[i] { + t.Errorf("expected record ID %s, got %s", test.expected[i], tidy.deletedRecordIds[i]) + } + } + }) + } +} + +func TestCreateRecord(t *testing.T) { + zones := []tidydns.Zone{ + {Name: "example.com", ID: "1"}, + {Name: "example.org", ID: "2"}, + } + + tests := []struct { + name string + zones []tidydns.Zone + encounterErr error + endpoint *Endpoint + expected []tidydns.Record + }{ + { + name: "Create A record", + zones: zones, + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("create.example.com", "A", 300, "1.2.3.4"), + expected: []tidydns.Record{ + { + Type: "A", + Name: "create", + Destination: "1.2.3.4", + TTL: json.Number("300"), + }, + }, + }, + { + name: "Error on create A record", + zones: zones, + encounterErr: fmt.Errorf("create record error"), + endpoint: endpoint.NewEndpointWithTTL("create.example.com", "A", 300, "1.2.3.4"), + expected: []tidydns.Record{}, + }, + { + name: "Create CNAME record", + zones: zones, + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("www.example.com", "CNAME", 300, "example.com"), + expected: []tidydns.Record{ + { + Type: "CNAME", + Name: "www", + Destination: "example.com.", + TTL: json.Number("300"), + }, + }, + }, + { + name: "Create TXT record", + zones: zones, + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("txt.example.com", "TXT", 300, "\"v=spf1 include:example.com ~all\""), + expected: []tidydns.Record{ + { + Type: "TXT", + Name: "txt", + Destination: "v=spf1 include:example.com ~all", + TTL: json.Number("300"), + }, + }, + }, + { + name: "Create record with TTL below minimum", + zones: zones, + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("lowttl.example.com", "A", 100, "1.2.3.4"), + expected: []tidydns.Record{ + { + Type: "A", + Name: "lowttl", + Destination: "1.2.3.4", + TTL: json.Number("300"), + }, + }, + }, + { + name: "Create record with no zones", + zones: []tidydns.Zone{}, + encounterErr: nil, + endpoint: endpoint.NewEndpointWithTTL("nozone.example.com", "A", 300, "1.2.3.4"), + expected: []tidydns.Record{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tidy := &mockTidyDNSClient{ + err: test.encounterErr, + } + + provider := &tidyProvider{ + tidy: tidy, + zoneProvider: &mockZoneProvider{}, + } + + provider.createRecord(test.zones, test.endpoint) + + if len(tidy.createdRecords) != len(test.expected) { + t.Fatalf("expected %d records to be created, got %d", len(test.expected), len(tidy.createdRecords)) + } + + for i, record := range tidy.createdRecords { + if record.Type != test.expected[i].Type || record.Name != test.expected[i].Name || record.Destination != test.expected[i].Destination || record.TTL != test.expected[i].TTL { + t.Errorf("expected record %+v, got %+v", test.expected[i], record) + } + } + }) + } +} + +func TestParseTidyRecord(t *testing.T) { + tests := []struct { + name string + record tidyRecord + expected *Endpoint + }{ + { + name: "A record", + record: tidyRecord{ + ID: "1", + Type: "A", + Name: "example", + Description: "Test A record", + Destination: "1.2.3.4", + TTL: "300", + ZoneName: "example.com", + ZoneID: "1", + }, + expected: endpoint.NewEndpointWithTTL("example.example.com", "A", 300, "1.2.3.4"), + }, + { + name: "CNAME record", + record: tidyRecord{ + ID: "2", + Type: "CNAME", + Name: "www", + Description: "Test CNAME record", + Destination: "example.com.", + TTL: "300", + ZoneName: "example.com", + ZoneID: "1", + }, + expected: endpoint.NewEndpointWithTTL("www.example.com", "CNAME", 300, "example.com"), + }, + { + name: "TXT record", + record: tidyRecord{ + ID: "3", + Type: "TXT", + Name: "txt", + Description: "Test TXT record", + Destination: "\"v=spf1 include:example.com ~all\"", + TTL: "300", + ZoneName: "example.com", + ZoneID: "1", + }, + expected: endpoint.NewEndpointWithTTL("txt.example.com", "TXT", 300, "\"v=spf1 include:example.com ~all\""), + }, + { + name: "Invalid TTL", + record: tidyRecord{ + ID: "4", + Type: "A", + Name: "invalid-ttl", + Description: "Test invalid TTL", + Destination: "1.2.3.4", + TTL: "invalid", + ZoneName: "example.com", + ZoneID: "1", + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := parseTidyRecord(&test.record) + if result == nil && test.expected != nil { + t.Errorf("expected %v, got nil", test.expected) + } else if result != nil && test.expected == nil { + t.Errorf("expected nil, got %v", result) + } else if result != nil && test.expected != nil { + if result.DNSName != test.expected.DNSName || result.RecordType != test.expected.RecordType || result.RecordTTL != test.expected.RecordTTL || len(result.Targets) != len(test.expected.Targets) || result.Targets[0] != test.expected.Targets[0] { + t.Errorf("expected %v, got %v", test.expected, result) + } + } + }) + } +} + +func TestTidyNameToFQDN(t *testing.T) { + tests := []struct { + name string + inputName string + inputZone string + expected string + }{ + {"Root domain", ".", "example.com", "example.com"}, + {"Subdomain", "sub", "example.com", "sub.example.com"}, + {"Root domain with dot", ".", "example.org", "example.org"}, + {"Subdomain with dot", "sub", "example.org", "sub.example.org"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := tidyNameToFQDN(test.inputName, test.inputZone) + if result != test.expected { + t.Errorf("expected %s, got %s", test.expected, result) + } + }) + } +} + +func TestClampTTL(t *testing.T) { + tests := []struct { + name string + inputTTL int + expected int + }{ + {"TTL below minimum", 100, 300}, + {"TTL at minimum", 300, 300}, + {"TTL above minimum", 600, 600}, + {"TTL zero", 0, 0}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := clampTTL(test.inputTTL) + if result != test.expected { + t.Errorf("expected %d, got %d", test.expected, result) + } + }) + } +} + +func TestTidyfyName(t *testing.T) { + zones := []tidydns.Zone{ + {Name: "example.com", ID: "1"}, + {Name: "example.org", ID: "2"}, + } + + tests := []struct { + name string + fqdn string + expected string + zoneID json.Number + }{ + {"Root domain", "example.com", ".", "1"}, + {"Subdomain", "sub.example.com", "sub", "1"}, + {"Root domain org", "example.org", ".", "2"}, + {"Subdomain org", "sub.example.org", "sub", "2"}, + {"Non-matching domain", "example.net", "", "0"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, zoneID := tidyfyName(zones, test.fqdn) + if result != test.expected || zoneID != test.zoneID { + t.Errorf("expected (%s, %s), got (%s, %s)", test.expected, test.zoneID, result, zoneID) + } + }) + } +} diff --git a/cmd/webhook/server_test.go b/cmd/webhook/server_test.go new file mode 100644 index 0000000..a1bb6a3 --- /dev/null +++ b/cmd/webhook/server_test.go @@ -0,0 +1,37 @@ +/* +Copyright 2024 Netic A/S. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestHealthz(t *testing.T) { + req, err := http.NewRequest("GET", "/healthz", nil) + if err != nil { + t.Fatalf("Could not create request: %v", err) + } + + rec := httptest.NewRecorder() + healthz(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Expected status OK; got %v", rec.Code) + } +} diff --git a/cmd/webhook/tidydns/metrics_test.go b/cmd/webhook/tidydns/metrics_test.go new file mode 100644 index 0000000..06787c0 --- /dev/null +++ b/cmd/webhook/tidydns/metrics_test.go @@ -0,0 +1,58 @@ +/* +Copyright 2024 Netic A/S. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tidydns + +import ( + "fmt" + "testing" + + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/noop" +) + +func TestCounterProvider(t *testing.T) { + meter := noop.NewMeterProvider().Meter("test") + + counter, err := counterProvider(meter, "test_counter", "Test counter description") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if counter == nil { + t.Fatalf("Expected a valid counter function, got nil") + } + + // Test the counter function + counter("GET", "/test", 200) +} + +type badMeter struct { + noop.Meter +} + +func (m *badMeter) Int64Counter(name string, options ...metric.Int64CounterOption) (metric.Int64Counter, error) { + return nil, fmt.Errorf("error") +} + +func TestCounterProviderError(t *testing.T) { + meter := &badMeter{} + _, err := counterProvider(meter, "test_counter", "Test counter description") + + if err == nil { + t.Fatalf("Expected an error, got nil") + } +} diff --git a/cmd/webhook/tidydns/tidydns_test.go b/cmd/webhook/tidydns/tidydns_test.go new file mode 100644 index 0000000..4aaf1e4 --- /dev/null +++ b/cmd/webhook/tidydns/tidydns_test.go @@ -0,0 +1,222 @@ +/* +Copyright 2024 Netic A/S. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tidydns + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "go.opentelemetry.io/otel/metric/noop" +) + +func mockCounter(method, url string, code int) { + // Do nothings +} + +func TestNewTidyDnsClient(t *testing.T) { + meter := noop.NewMeterProvider().Meter("test") + client, err := NewTidyDnsClient("http://example.com", "user", "pass", (10 * time.Second), meter) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if client == nil { + t.Fatalf("Expected client, got nil") + } +} + +func TestNewTidyDnsClientErrBadMeter(t *testing.T) { + meter := &badMeter{} + _, err := NewTidyDnsClient("http://example.com", "user", "pass", (10 * time.Second), meter) + if err == nil { + t.Fatalf("Expected an error, got nil") + } +} + +func TestListZones(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`[{"id": "1", "name": "zone1"}, {"id": "2", "name": "zone2"}]`)) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &tidyDNSClient{ + client: server.Client(), + baseURL: server.URL, + username: "user", + password: "pass", + counter: mockCounter, + } + + zones, err := client.ListZones() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if len(zones) != 2 { + t.Fatalf("Expected 2 zones, got %d", len(zones)) + } +} + +func TestCreateRecord(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &tidyDNSClient{ + client: server.Client(), + baseURL: server.URL, + username: "user", + password: "pass", + counter: mockCounter, + } + + record := &Record{ + Type: "A", + Name: "test", + Description: "Test record", + Destination: "1.2.3.4", + TTL: "300", + } + + err := client.CreateRecord("1", record) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } +} + +func TestCreateRecordFailure(t *testing.T) { + client := &tidyDNSClient{} + record := &Record{ + Type: "a", + Name: "test", + Description: "Test record", + Destination: "1.2.3.4", + TTL: "300", + } + + err := client.CreateRecord("1", record) + if err == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestListRecords(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`[{"id": "1", "type_name": "A", "name": "test", "description": "Test record", "destination": "1.2.3.4", "ttl": "300", "zone_name": "example.com", "zone_id": "1"}]`)) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &tidyDNSClient{ + client: server.Client(), + baseURL: server.URL, + username: "user", + password: "pass", + counter: mockCounter, + } + + records, err := client.ListRecords("1") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if len(records) != 1 { + t.Fatalf("Expected 1 record, got %d", len(records)) + } +} + +func TestDeleteRecord(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &tidyDNSClient{ + client: server.Client(), + baseURL: server.URL, + username: "user", + password: "pass", + counter: mockCounter, + } + + err := client.DeleteRecord("1", "1") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } +} + +func TestRequestErrBadRequest(t *testing.T) { + client := &tidyDNSClient{ + baseURL: "http://example.com", + } + + err := client.request("GET", "/tes\t", nil, nil) + if err == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestRequestErrorHandling(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &tidyDNSClient{ + client: server.Client(), + baseURL: server.URL, + username: "user", + password: "pass", + counter: mockCounter, + } + + err := client.request("GET", "/test", nil, nil) + if err == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestEncodeRecordType(t *testing.T) { + tests := []struct { + input string + expected RecordType + err error + }{ + {"AAAA", RecordTypeA, nil}, + {"A", RecordTypeA, nil}, + {"CNAME", RecordTypeCNAME, nil}, + {"TXT", RecordTypeTXT, nil}, + {"UNKNOWN", RecordType(0), errors.New("unmapped record type UNKNOWN")}, + } + + for _, test := range tests { + result, err := encodeRecordType(test.input) + if result != test.expected || (err != nil && err.Error() != test.err.Error()) { + t.Errorf("Expected %v and %v, got %v and %v", test.expected, test.err, result, err) + } + } +} diff --git a/cmd/webhook/zoneprovider_test.go b/cmd/webhook/zoneprovider_test.go new file mode 100644 index 0000000..337e622 --- /dev/null +++ b/cmd/webhook/zoneprovider_test.go @@ -0,0 +1,150 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "errors" + "testing" + "time" + + "github.com/neticdk/external-dns-tidydns-webhook/cmd/webhook/tidydns" +) + +func TestNewZoneProvider(t *testing.T) { + mockZones := []tidydns.Zone{ + {Name: "zone1"}, + {Name: "zone2"}, + } + + mockClient := &mockTidyDNSClient{zones: mockZones} + provider := newZoneProvider(mockClient, (10 * time.Minute)) + + zones := provider.getZones() + if len(zones) != len(mockZones) { + t.Fatalf("Expected %d zones, got %d", len(mockZones), len(zones)) + } + + for i, zone := range zones { + if zone.Name != mockZones[i].Name { + t.Errorf("Expected zone name %s, got %s", mockZones[i].Name, zone.Name) + } + } +} + +func TestZoneProviderUpdateWithError(t *testing.T) { + initialZones := []tidydns.Zone{ + {Name: "zone1"}, + } + + mockClient := &mockTidyDNSClient{zones: initialZones} + provider := newZoneProvider(mockClient, (1 * time.Second)) + + // Initial zones check + zones := provider.getZones() + if len(zones) != len(initialZones) { + t.Fatalf("Expected %d initial zones, got %d", len(initialZones), len(zones)) + } + + for i, zone := range zones { + if zone.Name != initialZones[i].Name { + t.Errorf("Expected initial zone name %s, got %s", initialZones[i].Name, zone.Name) + } + } + + // Introduce an error in the mock client + mockClient.err = errors.New("mock update error") + + // Wait for the update interval to pass + time.Sleep(2 * time.Second) + + // Check zones after error + zones = provider.getZones() + if len(zones) != len(initialZones) { + t.Fatalf("Expected %d zones after error, got %d", len(initialZones), len(zones)) + } + + for i, zone := range zones { + if zone.Name != initialZones[i].Name { + t.Errorf("Expected zone name %s after error, got %s", initialZones[i].Name, zone.Name) + } + } +} + +func TestZoneProviderUpdateWithNewZones(t *testing.T) { + initialZones := []tidydns.Zone{ + {Name: "zone1"}, + } + + updatedZones := []tidydns.Zone{ + {Name: "zone1"}, + {Name: "zone2"}, + } + + mockClient := &mockTidyDNSClient{zones: initialZones} + provider := newZoneProvider(mockClient, (1 * time.Second)) + + // Initial zones check + zones := provider.getZones() + if len(zones) != len(initialZones) { + t.Fatalf("Expected %d initial zones, got %d", len(initialZones), len(zones)) + } + + for i, zone := range zones { + if zone.Name != initialZones[i].Name { + t.Errorf("Expected initial zone name %s, got %s", initialZones[i].Name, zone.Name) + } + } + + // Update the zones in the mock client + mockClient.zones = updatedZones + + // Wait for the update interval to pass + time.Sleep(2 * time.Second) + + // Check zones after update + zones = provider.getZones() + if len(zones) != len(updatedZones) { + t.Fatalf("Expected %d zones after update, got %d", len(updatedZones), len(zones)) + } + + for i, zone := range zones { + if zone.Name != updatedZones[i].Name { + t.Errorf("Expected zone name %s after update, got %s", updatedZones[i].Name, zone.Name) + } + } +} + +func TestZoneProviderErrorHandling(t *testing.T) { + mockClient := &mockTidyDNSClient{err: errors.New("mock error")} + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic due to error in ListZones") + } + }() + + newZoneProvider(mockClient, (10 * time.Minute)) +} + +func TestZoneProviderNoZones(t *testing.T) { + mockClient := &mockTidyDNSClient{zones: []tidydns.Zone{}} + + provider := newZoneProvider(mockClient, (10 * time.Minute)) + + zones := provider.getZones() + if len(zones) != 0 { + t.Fatalf("Expected 0 zones, got %d", len(zones)) + } +}