From 4bed72fbbeb884c46df055d012f1c940f845d61b Mon Sep 17 00:00:00 2001 From: Philip Laine Date: Fri, 2 Dec 2022 11:10:38 +0100 Subject: [PATCH] Update Azure SDK and remove deprecated autorest dependency --- go.mod | 18 +- go.sum | 28 ++- provider/azure/azure.go | 213 ++++++++--------- provider/azure/azure_private_dns.go | 204 ++++++++-------- provider/azure/azure_privatedns_test.go | 268 +++++++++------------ provider/azure/azure_test.go | 295 +++++++++--------------- provider/azure/config.go | 93 ++++---- provider/azure/config_test.go | 42 +--- 8 files changed, 485 insertions(+), 676 deletions(-) diff --git a/go.mod b/go.mod index f670123cb5..a2eff7fd6a 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,12 @@ go 1.19 require ( cloud.google.com/go/compute v1.9.0 - github.com/Azure/azure-sdk-for-go v66.0.0+incompatible - github.com/Azure/go-autorest/autorest v0.11.27 - github.com/Azure/go-autorest/autorest/adal v0.9.20 - github.com/Azure/go-autorest/autorest/to v0.4.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.2.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.0.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.0.0 + github.com/Azure/go-autorest/autorest v0.11.27 // indirect + github.com/Azure/go-autorest/autorest/adal v0.9.20 // indirect github.com/IBM-Cloud/ibm-cloud-cli-sdk v1.0.0 github.com/IBM/go-sdk-core/v5 v5.8.0 github.com/IBM/networking-go-sdk v0.32.0 @@ -77,10 +79,12 @@ require ( require ( code.cloudfoundry.org/gofileutils v0.0.0-20170111115228-4d0c80011a0f // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect github.com/Azure/go-autorest/logger v0.2.1 // indirect github.com/Azure/go-autorest/tracing v0.6.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0 // indirect github.com/Masterminds/semver v1.4.2 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect @@ -115,7 +119,7 @@ require ( github.com/go-stack/stack v1.8.0 // indirect github.com/gofrs/uuid v3.2.0+incompatible // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v4 v4.2.0 // indirect + github.com/golang-jwt/jwt/v4 v4.4.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/gnostic v0.5.7-v3refs // indirect @@ -140,6 +144,7 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.2.0 // indirect github.com/mailru/easyjson v0.7.6 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect @@ -156,6 +161,7 @@ require ( github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/peterhellberg/link v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.32.1 // indirect @@ -174,7 +180,7 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.19.1 // indirect - golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f // indirect + golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/sys v0.1.0 // indirect golang.org/x/term v0.1.0 // indirect diff --git a/go.sum b/go.sum index e2c8d9e515..da5d656b03 100644 --- a/go.sum +++ b/go.sum @@ -69,8 +69,16 @@ gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zum git.lukeshu.com/go/libsystemd v0.5.3/go.mod h1:FfDoP0i92r4p5Vn4NCLxvjkd7rCOe6otPa4L6hZg9WM= github.com/Azure/azure-sdk-for-go v16.2.1+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go v56.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-sdk-for-go v66.0.0+incompatible h1:bmmC38SlE8/E81nNADlgmVGurPWMHDX2YNXVQMrBpEE= -github.com/Azure/azure-sdk-for-go v66.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.2.0 h1:sVW/AFBTGyJxDaMYlq0ct3jUXTtj12tQ6zE2GZUgVQw= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.2.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.0 h1:t/W5MYAuQy81cvM8VUNfRLzhtKpXhVUAN7Cd7KVbTyc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.0/go.mod h1:NBanQUfSWiWn3QEpWDTCU0IjBECKOYvl2R8xdRtMtiM= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1 h1:Oj853U9kG+RLTCQXpjvOnrv0WaZHxgmZz1TlLywgOPY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.0.0 h1:yxl7xvG5sSVlR74BqjIg+dnoE82jeolZF62X1gMT2VY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.0.0/go.mod h1:eADizCOKKdr+Q+7TFPNaPh+MIjbfJ42F0snpJZwRAtU= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.0.0 h1:QSyXXkeDeNixC785fGcy+VZ2zXsj8rHW1ez6xKvDf9g= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.0.0/go.mod h1:5lNwEBWq9A4saepAIVUHxiVeMCjnVT7GXclert1Gpsw= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/Azure/go-ansiterm v0.0.0-20210608223527-2377c96fe795/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= @@ -101,7 +109,6 @@ github.com/Azure/go-autorest/autorest/mocks v0.4.0/go.mod h1:LTp+uSrOhSkaKrUy935 github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9AGVwYHG9/fkDYhtAfw= github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU= -github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk= github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= github.com/Azure/go-autorest/autorest/validation v0.1.0/go.mod h1:Ha3z/SqBeaalWQvokg3NZAlQTalVMtOIAs1aGK7G6u8= github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= @@ -111,6 +118,8 @@ github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0 h1:VgSJlZH5u0k2qxSpqyghcFQKmvYckj46uymKK5XzkBM= +github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5qMFKx9DugEg3+uQSDCdbYPr5s9vBTrL9P8TpqOU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= @@ -361,8 +370,8 @@ github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8 github.com/digitalocean/godo v1.64.2/go.mod h1:p7dOjjtSBqCTUksqtA5Fd3uaKs9kyTq2xcz76ulEJRU= github.com/digitalocean/godo v1.81.0 h1:sjb3fOfPfSlUQUK22E87BcI8Zx2qtnF7VUCCO4UK3C8= github.com/digitalocean/godo v1.81.0/go.mod h1:BPCqvwbjbGqxuUnIKB4EvS/AX7IDnNmt5fwvIkWo+ew= -github.com/dnaeon/go-vcr v1.0.1 h1:r8L/HqC0Hje5AXMu1ooW8oyQyOFv4GxqpL0nRP7SLLY= github.com/dnaeon/go-vcr v1.0.1/go.mod h1:aBB1+wY4s93YsC3HHjMBMrwTj2R9FHDzUr9KyGc8n1E= +github.com/dnaeon/go-vcr v1.1.0 h1:ReYa/UBrRyQdant9B4fNHGoCNKw6qh6P0fsdGmZpR7c= github.com/dnsimple/dnsimple-go v0.71.1 h1:1hGoBA3CIjpjZj5DM3081xfxr4e2jYmYnkO2VuBF8Qc= github.com/dnsimple/dnsimple-go v0.71.1/go.mod h1:F9WHww9cC76hrnwGFfAfrqdW99j3MOYasQcIwTS/aUk= github.com/docker/cli v0.0.0-20200130152716-5d0cf8839492/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= @@ -616,8 +625,9 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= -github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs= +github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/gddo v0.0.0-20190419222130-af0f2af80721/go.mod h1:xEhNfoBDX1hzLm2Nf80qUvZ2sVwoMZ8d6IE2SrsQfh4= @@ -950,6 +960,8 @@ github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= @@ -1192,6 +1204,8 @@ github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0 github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= +github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= +github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1562,8 +1576,8 @@ golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.0.0-20211215165025-cf75a172585e/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= -golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 h1:Tgea0cVUD0ivh5ADBX4WwuI12DUd2to3nCYe2eayMIw= +golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= diff --git a/provider/azure/azure.go b/provider/azure/azure.go index 2b2ff29246..21395bdaa2 100644 --- a/provider/azure/azure.go +++ b/provider/azure/azure.go @@ -23,9 +23,9 @@ import ( log "github.com/sirupsen/logrus" - "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2018-05-01/dns" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/to" + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" @@ -38,14 +38,14 @@ const ( // ZonesClient is an interface of dns.ZoneClient that can be stubbed for testing. type ZonesClient interface { - ListByResourceGroupComplete(ctx context.Context, resourceGroupName string, top *int32) (result dns.ZoneListResultIterator, err error) + NewListByResourceGroupPager(resourceGroupName string, options *dns.ZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[dns.ZonesClientListByResourceGroupResponse] } // RecordSetsClient is an interface of dns.RecordSetsClient that can be stubbed for testing. type RecordSetsClient interface { - ListAllByDNSZoneComplete(ctx context.Context, resourceGroupName string, zoneName string, top *int32, recordSetNameSuffix string) (result dns.RecordSetListResultIterator, err error) - Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, ifMatch string) (result autorest.Response, err error) - CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, ifMatch string, ifNoneMatch string) (result dns.RecordSet, err error) + NewListAllByDNSZonePager(resourceGroupName string, zoneName string, options *dns.RecordSetsClientListAllByDNSZoneOptions) *azcoreruntime.Pager[dns.RecordSetsClientListAllByDNSZoneResponse] + Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, options *dns.RecordSetsClientDeleteOptions) (dns.RecordSetsClientDeleteResponse, error) + CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, options *dns.RecordSetsClientCreateOrUpdateOptions) (dns.RecordSetsClientCreateOrUpdateResponse, error) } // AzureProvider implements the DNS provider for Microsoft's Azure cloud platform. @@ -69,17 +69,18 @@ func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zon if err != nil { return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) } - - token, err := getAccessToken(*cfg, cfg.Environment) + cred, err := getCredentials(*cfg) if err != nil { - return nil, fmt.Errorf("failed to get token: %v", err) + return nil, fmt.Errorf("failed to get credentials: %v", err) + } + zonesClient, err := dns.NewZonesClient(cfg.SubscriptionID, nil, nil) + if err != nil { + return nil, err + } + recordSetsClient, err := dns.NewRecordSetsClient(cfg.SubscriptionID, cred, nil) + if err != nil { + return nil, err } - - zonesClient := dns.NewZonesClientWithBaseURI(cfg.Environment.ResourceManagerEndpoint, cfg.SubscriptionID) - zonesClient.Authorizer = autorest.NewBearerAuthorizer(token) - recordSetsClient := dns.NewRecordSetsClientWithBaseURI(cfg.Environment.ResourceManagerEndpoint, cfg.SubscriptionID) - recordSetsClient.Authorizer = autorest.NewBearerAuthorizer(token) - return &AzureProvider{ domainFilter: domainFilter, zoneNameFilter: zoneNameFilter, @@ -102,43 +103,44 @@ func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endp } for _, zone := range zones { - err := p.iterateRecords(ctx, *zone.Name, func(recordSet dns.RecordSet) bool { - if recordSet.Name == nil || recordSet.Type == nil { - log.Error("Skipping invalid record set with nil name or type.") - return true - } - recordType := strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/dnszones/") - if !provider.SupportedRecordType(recordType) { - return true - } - name := formatAzureDNSName(*recordSet.Name, *zone.Name) - - if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) { - log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name) - return true - } - targets := extractAzureTargets(&recordSet) - if len(targets) == 0 { - log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType) - return true + pager := p.recordSetsClient.NewListAllByDNSZonePager(p.resourceGroup, *zone.Name, nil) + for pager.More() { + nextResult, err := pager.NextPage(ctx) + if err != nil { + return nil, err } - var ttl endpoint.TTL - if recordSet.TTL != nil { - ttl = endpoint.TTL(*recordSet.TTL) + for _, recordSet := range nextResult.Value { + if recordSet.Name == nil || recordSet.Type == nil { + log.Error("Skipping invalid record set with nil name or type.") + continue + } + recordType := strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/dnszones/") + if !provider.SupportedRecordType(recordType) { + continue + } + name := formatAzureDNSName(*recordSet.Name, *zone.Name) + if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) { + log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name) + continue + } + targets := extractAzureTargets(recordSet) + if len(targets) == 0 { + log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType) + continue + } + var ttl endpoint.TTL + if recordSet.Properties.TTL != nil { + ttl = endpoint.TTL(*recordSet.Properties.TTL) + } + ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...) + log.Debugf( + "Found %s record for '%s' with target '%s'.", + ep.RecordType, + ep.DNSName, + ep.Targets, + ) + endpoints = append(endpoints, ep) } - - ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...) - log.Debugf( - "Found %s record for '%s' with target '%s'.", - ep.RecordType, - ep.DNSName, - ep.Targets, - ) - endpoints = append(endpoints, ep) - return true - }) - if err != nil { - return nil, err } } return endpoints, nil @@ -161,56 +163,26 @@ func (p *AzureProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) func (p *AzureProvider) zones(ctx context.Context) ([]dns.Zone, error) { log.Debugf("Retrieving Azure DNS zones for resource group: %s.", p.resourceGroup) - var zones []dns.Zone - - zonesIterator, err := p.zonesClient.ListByResourceGroupComplete(ctx, p.resourceGroup, nil) - if err != nil { - return nil, err - } - - for zonesIterator.NotDone() { - zone := zonesIterator.Value() - - if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) { - zones = append(zones, zone) - } else if zone.Name != nil && len(p.zoneNameFilter.Filters) > 0 && p.zoneNameFilter.Match(*zone.Name) { - // Handle zoneNameFilter - zones = append(zones, zone) - } - - err := zonesIterator.NextWithContext(ctx) + pager := p.zonesClient.NewListByResourceGroupPager(p.resourceGroup, nil) + for pager.More() { + nextResult, err := pager.NextPage(ctx) if err != nil { return nil, err } + for _, zone := range nextResult.Value { + if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) { + zones = append(zones, *zone) + } else if zone.Name != nil && len(p.zoneNameFilter.Filters) > 0 && p.zoneNameFilter.Match(*zone.Name) { + // Handle zoneNameFilter + zones = append(zones, *zone) + } + } } - log.Debugf("Found %d Azure DNS zone(s).", len(zones)) return zones, nil } -func (p *AzureProvider) iterateRecords(ctx context.Context, zoneName string, callback func(dns.RecordSet) bool) error { - log.Debugf("Retrieving Azure DNS records for zone '%s'.", zoneName) - - recordSetsIterator, err := p.recordSetsClient.ListAllByDNSZoneComplete(ctx, p.resourceGroup, zoneName, nil, "") - if err != nil { - return err - } - - for recordSetsIterator.NotDone() { - if !callback(recordSetsIterator.Value()) { - return nil - } - - err := recordSetsIterator.NextWithContext(ctx) - if err != nil { - return err - } - } - - return nil -} - type azureChangeMap map[string][]*endpoint.Endpoint func (p *AzureProvider) mapChanges(zones []dns.Zone, changes *plan.Changes) (azureChangeMap, azureChangeMap) { @@ -267,7 +239,7 @@ func (p *AzureProvider) deleteRecords(ctx context.Context, deleted azureChangeMa log.Infof("Would delete %s record named '%s' for Azure DNS zone '%s'.", ep.RecordType, name, zone) } else { log.Infof("Deleting %s record named '%s' for Azure DNS zone '%s'.", ep.RecordType, name, zone) - if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, name, dns.RecordType(ep.RecordType), ""); err != nil { + if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, name, dns.RecordType(ep.RecordType), nil); err != nil { log.Errorf( "Failed to delete %s record named '%s' for Azure DNS zone '%s': %v", ep.RecordType, @@ -317,8 +289,7 @@ func (p *AzureProvider) updateRecords(ctx context.Context, updated azureChangeMa name, dns.RecordType(ep.RecordType), recordSet, - "", - "", + nil, ) } if err != nil { @@ -354,36 +325,36 @@ func (p *AzureProvider) newRecordSet(endpoint *endpoint.Endpoint) (dns.RecordSet ttl = int64(endpoint.RecordTTL) } switch dns.RecordType(endpoint.RecordType) { - case dns.A: - aRecords := make([]dns.ARecord, len(endpoint.Targets)) + case dns.RecordTypeA: + aRecords := make([]*dns.ARecord, len(endpoint.Targets)) for i, target := range endpoint.Targets { - aRecords[i] = dns.ARecord{ - Ipv4Address: to.StringPtr(target), + aRecords[i] = &dns.ARecord{ + IPv4Address: to.Ptr(target), } } return dns.RecordSet{ - RecordSetProperties: &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - ARecords: &aRecords, + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + ARecords: aRecords, }, }, nil - case dns.CNAME: + case dns.RecordTypeCNAME: return dns.RecordSet{ - RecordSetProperties: &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), CnameRecord: &dns.CnameRecord{ - Cname: to.StringPtr(endpoint.Targets[0]), + Cname: to.Ptr(endpoint.Targets[0]), }, }, }, nil - case dns.TXT: + case dns.RecordTypeTXT: return dns.RecordSet{ - RecordSetProperties: &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - TxtRecords: &[]dns.TxtRecord{ + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + TxtRecords: []*dns.TxtRecord{ { - Value: &[]string{ - endpoint.Targets[0], + Value: []*string{ + &endpoint.Targets[0], }, }, }, @@ -403,17 +374,17 @@ func formatAzureDNSName(recordName, zoneName string) string { // Helper function (shared with text code) func extractAzureTargets(recordSet *dns.RecordSet) []string { - properties := recordSet.RecordSetProperties + properties := recordSet.Properties if properties == nil { return []string{} } // Check for A records aRecords := properties.ARecords - if aRecords != nil && len(*aRecords) > 0 && (*aRecords)[0].Ipv4Address != nil { - targets := make([]string, len(*aRecords)) - for i, aRecord := range *aRecords { - targets[i] = *aRecord.Ipv4Address + if aRecords != nil && len(aRecords) > 0 && (aRecords)[0].IPv4Address != nil { + targets := make([]string, len(aRecords)) + for i, aRecord := range aRecords { + targets[i] = *aRecord.IPv4Address } return targets } @@ -426,10 +397,10 @@ func extractAzureTargets(recordSet *dns.RecordSet) []string { // Check for TXT records txtRecords := properties.TxtRecords - if txtRecords != nil && len(*txtRecords) > 0 && (*txtRecords)[0].Value != nil { - values := (*txtRecords)[0].Value - if values != nil && len(*values) > 0 { - return []string{(*values)[0]} + if txtRecords != nil && len(txtRecords) > 0 && (txtRecords)[0].Value != nil { + values := (txtRecords)[0].Value + if values != nil && len(values) > 0 { + return []string{*(values)[0]} } } return []string{} diff --git a/provider/azure/azure_private_dns.go b/provider/azure/azure_private_dns.go index 320def42aa..7e7ddc4540 100644 --- a/provider/azure/azure_private_dns.go +++ b/provider/azure/azure_private_dns.go @@ -21,9 +21,9 @@ import ( "fmt" "strings" - "github.com/Azure/azure-sdk-for-go/services/privatedns/mgmt/2018-09-01/privatedns" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/to" + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" log "github.com/sirupsen/logrus" "sigs.k8s.io/external-dns/endpoint" @@ -33,14 +33,14 @@ import ( // PrivateZonesClient is an interface of privatedns.PrivateZoneClient that can be stubbed for testing. type PrivateZonesClient interface { - ListByResourceGroupComplete(ctx context.Context, resourceGroupName string, top *int32) (result privatedns.PrivateZoneListResultIterator, err error) + NewListByResourceGroupPager(resourceGroupName string, options *privatedns.PrivateZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[privatedns.PrivateZonesClientListByResourceGroupResponse] } // PrivateRecordSetsClient is an interface of privatedns.RecordSetsClient that can be stubbed for testing. type PrivateRecordSetsClient interface { - ListComplete(ctx context.Context, resourceGroupName string, zoneName string, top *int32, recordSetNameSuffix string) (result privatedns.RecordSetListResultIterator, err error) - Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, ifMatch string) (result autorest.Response, err error) - CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, ifMatch string, ifNoneMatch string) (result privatedns.RecordSet, err error) + NewListPager(resourceGroupName string, privateZoneName string, options *privatedns.RecordSetsClientListOptions) *azcoreruntime.Pager[privatedns.RecordSetsClientListResponse] + Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, options *privatedns.RecordSetsClientDeleteOptions) (privatedns.RecordSetsClientDeleteResponse, error) + CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, options *privatedns.RecordSetsClientCreateOrUpdateOptions) (privatedns.RecordSetsClientCreateOrUpdateResponse, error) } // AzurePrivateDNSProvider implements the DNS provider for Microsoft's Azure Private DNS service @@ -63,17 +63,18 @@ func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainF if err != nil { return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) } - - token, err := getAccessToken(*cfg, cfg.Environment) + cred, err := getCredentials(*cfg) if err != nil { - return nil, fmt.Errorf("failed to get token: %v", err) + return nil, fmt.Errorf("failed to get credentials: %v", err) + } + zonesClient, err := privatedns.NewPrivateZonesClient(cfg.SubscriptionID, cred, nil) + if err != nil { + return nil, err + } + recordSetsClient, err := privatedns.NewRecordSetsClient(cfg.SubscriptionID, cred, nil) + if err != nil { + return nil, err } - - zonesClient := privatedns.NewPrivateZonesClientWithBaseURI(cfg.Environment.ResourceManagerEndpoint, cfg.SubscriptionID) - zonesClient.Authorizer = autorest.NewBearerAuthorizer(token) - recordSetsClient := privatedns.NewRecordSetsClientWithBaseURI(cfg.Environment.ResourceManagerEndpoint, cfg.SubscriptionID) - recordSetsClient.Authorizer = autorest.NewBearerAuthorizer(token) - return &AzurePrivateDNSProvider{ domainFilter: domainFilter, zoneIDFilter: zoneIDFilter, @@ -97,43 +98,49 @@ func (p *AzurePrivateDNSProvider) Records(ctx context.Context) (endpoints []*end log.Debugf("Retrieving Azure Private DNS Records for resource group '%s'", p.resourceGroup) for _, zone := range zones { - err := p.iterateRecords(ctx, *zone.Name, func(recordSet privatedns.RecordSet) { - var recordType string - if recordSet.Type == nil { - log.Debugf("Skipping invalid record set with missing type.") - return + pager := p.recordSetsClient.NewListPager(p.resourceGroup, *zone.Name, nil) + for pager.More() { + nextResult, err := pager.NextPage(ctx) + if err != nil { + return nil, err } - recordType = strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/privateDnsZones/") - var name string - if recordSet.Name == nil { - log.Debugf("Skipping invalid record set with missing name.") - return - } - name = formatAzureDNSName(*recordSet.Name, *zone.Name) + for _, recordSet := range nextResult.Value { + var recordType string + if recordSet.Type == nil { + log.Debugf("Skipping invalid record set with missing type.") + continue + } + // TODO: Is this still required? + recordType = strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/privateDnsZones/") - targets := extractAzurePrivateDNSTargets(&recordSet) - if len(targets) == 0 { - log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType) - return - } + var name string + if recordSet.Name == nil { + log.Debugf("Skipping invalid record set with missing name.") + continue + } + name = formatAzureDNSName(*recordSet.Name, *zone.Name) - var ttl endpoint.TTL - if recordSet.TTL != nil { - ttl = endpoint.TTL(*recordSet.TTL) - } + targets := extractAzurePrivateDNSTargets(recordSet) + if len(targets) == 0 { + log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType) + continue + } - ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...) - log.Debugf( - "Found %s record for '%s' with target '%s'.", - ep.RecordType, - ep.DNSName, - ep.Targets, - ) - endpoints = append(endpoints, ep) - }) - if err != nil { - return nil, err + var ttl endpoint.TTL + if recordSet.Properties.TTL != nil { + ttl = endpoint.TTL(*recordSet.Properties.TTL) + } + + ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...) + log.Debugf( + "Found %s record for '%s' with target '%s'.", + ep.RecordType, + ep.DNSName, + ep.Targets, + ) + endpoints = append(endpoints, ep) + } } } @@ -164,47 +171,23 @@ func (p *AzurePrivateDNSProvider) zones(ctx context.Context) ([]privatedns.Priva var zones []privatedns.PrivateZone - i, err := p.zonesClient.ListByResourceGroupComplete(ctx, p.resourceGroup, nil) - if err != nil { - return nil, err - } - - for i.NotDone() { - zone := i.Value() - log.Debugf("Validating Zone: %v", *zone.Name) - - if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) { - zones = append(zones, zone) - } - - err := i.NextWithContext(ctx) + pager := p.zonesClient.NewListByResourceGroupPager(p.resourceGroup, nil) + for pager.More() { + nextResult, err := pager.NextPage(ctx) if err != nil { return nil, err } - } + for _, zone := range nextResult.Value { + log.Debugf("Validating Zone: %v", *zone.Name) - log.Debugf("Found %d Azure Private DNS zone(s).", len(zones)) - return zones, nil -} - -func (p *AzurePrivateDNSProvider) iterateRecords(ctx context.Context, zoneName string, callback func(privatedns.RecordSet)) error { - log.Debugf("Retrieving Azure Private DNS Records for zone '%s'.", zoneName) - - i, err := p.recordSetsClient.ListComplete(ctx, p.resourceGroup, zoneName, nil, "") - if err != nil { - return err - } - - for i.NotDone() { - callback(i.Value()) - - err := i.NextWithContext(ctx) - if err != nil { - return err + if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) { + zones = append(zones, *zone) + } } } - return nil + log.Debugf("Found %d Azure Private DNS zone(s).", len(zones)) + return zones, nil } type azurePrivateDNSChangeMap map[string][]*endpoint.Endpoint @@ -260,7 +243,7 @@ func (p *AzurePrivateDNSProvider) deleteRecords(ctx context.Context, deleted azu log.Infof("Would delete %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone) } else { log.Infof("Deleting %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone) - if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, privatedns.RecordType(ep.RecordType), name, ""); err != nil { + if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, privatedns.RecordType(ep.RecordType), name, nil); err != nil { log.Errorf( "Failed to delete %s record named '%s' for Azure Private DNS zone '%s': %v", ep.RecordType, @@ -307,8 +290,7 @@ func (p *AzurePrivateDNSProvider) updateRecords(ctx context.Context, updated azu privatedns.RecordType(ep.RecordType), name, recordSet, - "", - "", + nil, ) } if err != nil { @@ -344,36 +326,36 @@ func (p *AzurePrivateDNSProvider) newRecordSet(endpoint *endpoint.Endpoint) (pri ttl = int64(endpoint.RecordTTL) } switch privatedns.RecordType(endpoint.RecordType) { - case privatedns.A: - aRecords := make([]privatedns.ARecord, len(endpoint.Targets)) + case privatedns.RecordTypeA: + aRecords := make([]*privatedns.ARecord, len(endpoint.Targets)) for i, target := range endpoint.Targets { - aRecords[i] = privatedns.ARecord{ - Ipv4Address: to.StringPtr(target), + aRecords[i] = &privatedns.ARecord{ + IPv4Address: to.Ptr(target), } } return privatedns.RecordSet{ - RecordSetProperties: &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - ARecords: &aRecords, + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + ARecords: aRecords, }, }, nil - case privatedns.CNAME: + case privatedns.RecordTypeCNAME: return privatedns.RecordSet{ - RecordSetProperties: &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), CnameRecord: &privatedns.CnameRecord{ - Cname: to.StringPtr(endpoint.Targets[0]), + Cname: to.Ptr(endpoint.Targets[0]), }, }, }, nil - case privatedns.TXT: + case privatedns.RecordTypeTXT: return privatedns.RecordSet{ - RecordSetProperties: &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - TxtRecords: &[]privatedns.TxtRecord{ + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + TxtRecords: []*privatedns.TxtRecord{ { - Value: &[]string{ - endpoint.Targets[0], + Value: []*string{ + &endpoint.Targets[0], }, }, }, @@ -385,17 +367,17 @@ func (p *AzurePrivateDNSProvider) newRecordSet(endpoint *endpoint.Endpoint) (pri // Helper function (shared with test code) func extractAzurePrivateDNSTargets(recordSet *privatedns.RecordSet) []string { - properties := recordSet.RecordSetProperties + properties := recordSet.Properties if properties == nil { return []string{} } // Check for A records aRecords := properties.ARecords - if aRecords != nil && len(*aRecords) > 0 && (*aRecords)[0].Ipv4Address != nil { - targets := make([]string, len(*aRecords)) - for i, aRecord := range *aRecords { - targets[i] = *aRecord.Ipv4Address + if aRecords != nil && len(aRecords) > 0 && (aRecords)[0].IPv4Address != nil { + targets := make([]string, len(aRecords)) + for i, aRecord := range aRecords { + targets[i] = *aRecord.IPv4Address } return targets } @@ -408,10 +390,10 @@ func extractAzurePrivateDNSTargets(recordSet *privatedns.RecordSet) []string { // Check for TXT records txtRecords := properties.TxtRecords - if txtRecords != nil && len(*txtRecords) > 0 && (*txtRecords)[0].Value != nil { - values := (*txtRecords)[0].Value - if values != nil && len(*values) > 0 { - return []string{(*values)[0]} + if txtRecords != nil && len(txtRecords) > 0 && (txtRecords)[0].Value != nil { + values := (txtRecords)[0].Value + if values != nil && len(values) > 0 { + return []string{*(values)[0]} } } return []string{} diff --git a/provider/azure/azure_privatedns_test.go b/provider/azure/azure_privatedns_test.go index f357201548..3d00b0d7f0 100644 --- a/provider/azure/azure_privatedns_test.go +++ b/provider/azure/azure_privatedns_test.go @@ -20,9 +20,9 @@ import ( "context" "testing" - "github.com/Azure/azure-sdk-for-go/services/privatedns/mgmt/2018-09-01/privatedns" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/to" + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/provider" @@ -35,100 +35,126 @@ const ( // mockPrivateZonesClient implements the methods of the Azure Private DNS Zones Client which are used in the Azure Private DNS Provider // and returns static results which are defined per test type mockPrivateZonesClient struct { - mockZonesClientIterator *privatedns.PrivateZoneListResultIterator + pagingHandler azcoreruntime.PagingHandler[privatedns.PrivateZonesClientListByResourceGroupResponse] +} + +func newMockPrivateZonesClient(zones []*privatedns.PrivateZone) mockPrivateZonesClient { + pagingHandler := azcoreruntime.PagingHandler[privatedns.PrivateZonesClientListByResourceGroupResponse]{ + More: func(resp privatedns.PrivateZonesClientListByResourceGroupResponse) bool { + return false + }, + Fetcher: func(context.Context, *privatedns.PrivateZonesClientListByResourceGroupResponse) (privatedns.PrivateZonesClientListByResourceGroupResponse, error) { + return privatedns.PrivateZonesClientListByResourceGroupResponse{ + PrivateZoneListResult: privatedns.PrivateZoneListResult{ + Value: zones, + }, + }, nil + }, + } + return mockPrivateZonesClient{ + pagingHandler: pagingHandler, + } +} + +func (client *mockPrivateZonesClient) NewListByResourceGroupPager(resourceGroupName string, options *privatedns.PrivateZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[privatedns.PrivateZonesClientListByResourceGroupResponse] { + return azcoreruntime.NewPager(client.pagingHandler) } // mockPrivateRecordSetsClient implements the methods of the Azure Private DNS RecordSet Client which are used in the Azure Private DNS Provider // and returns static results which are defined per test type mockPrivateRecordSetsClient struct { - mockRecordSetListIterator *privatedns.RecordSetListResultIterator - deletedEndpoints []*endpoint.Endpoint - updatedEndpoints []*endpoint.Endpoint + pagingHandler azcoreruntime.PagingHandler[privatedns.RecordSetsClientListResponse] + deletedEndpoints []*endpoint.Endpoint + updatedEndpoints []*endpoint.Endpoint } -// mockPrivateZoneListResultPageIterator is used to paginate forward through a list of zones -type mockPrivateZoneListResultPageIterator struct { - offset int - results []privatedns.PrivateZoneListResult -} - -// getNextPage provides the next page based on the offset of the mockZoneListResultPageIterator -func (m *mockPrivateZoneListResultPageIterator) getNextPage(context.Context, privatedns.PrivateZoneListResult) (privatedns.PrivateZoneListResult, error) { - // it assumed that instances of this kind of iterator are only skimmed through once per test - // otherwise a real implementation is required, e.g. based on a linked list - if m.offset < len(m.results) { - m.offset++ - return m.results[m.offset-1], nil +func newMockPrivateRecordSectsClient(recordSets []*privatedns.RecordSet) mockPrivateRecordSetsClient { + pagingHandler := azcoreruntime.PagingHandler[privatedns.RecordSetsClientListResponse]{ + More: func(resp privatedns.RecordSetsClientListResponse) bool { + return false + }, + Fetcher: func(context.Context, *privatedns.RecordSetsClientListResponse) (privatedns.RecordSetsClientListResponse, error) { + return privatedns.RecordSetsClientListResponse{ + RecordSetListResult: privatedns.RecordSetListResult{ + Value: recordSets, + }, + }, nil + }, + } + return mockPrivateRecordSetsClient{ + pagingHandler: pagingHandler, } - - // paged to last page or empty - return privatedns.PrivateZoneListResult{}, nil } -// mockPrivateRecordSetListResultPageIterator is used to paginate forward through a list of recordsets -type mockPrivateRecordSetListResultPageIterator struct { - offset int - results []privatedns.RecordSetListResult +func (client *mockPrivateRecordSetsClient) NewListPager(resourceGroupName string, privateZoneName string, options *privatedns.RecordSetsClientListOptions) *azcoreruntime.Pager[privatedns.RecordSetsClientListResponse] { + return azcoreruntime.NewPager(client.pagingHandler) } -// getNextPage provides the next page based on the offset of the mockRecordSetListResultPageIterator -func (m *mockPrivateRecordSetListResultPageIterator) getNextPage(context.Context, privatedns.RecordSetListResult) (privatedns.RecordSetListResult, error) { - // it assumed that instances of this kind of iterator are only skimmed through once per test - // otherwise a real implementation is required, e.g. based on a linked list - if m.offset < len(m.results) { - m.offset++ - return m.results[m.offset-1], nil - } - - // paged to last page or empty - return privatedns.RecordSetListResult{}, nil +func (client *mockPrivateRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, options *privatedns.RecordSetsClientDeleteOptions) (privatedns.RecordSetsClientDeleteResponse, error) { + client.deletedEndpoints = append( + client.deletedEndpoints, + endpoint.NewEndpoint( + formatAzureDNSName(relativeRecordSetName, privateZoneName), + string(recordType), + "", + ), + ) + return privatedns.RecordSetsClientDeleteResponse{}, nil } -func createMockPrivateZone(zone string, id string) privatedns.PrivateZone { - return privatedns.PrivateZone{ - ID: to.StringPtr(id), - Name: to.StringPtr(zone), +func (client *mockPrivateRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, options *privatedns.RecordSetsClientCreateOrUpdateOptions) (privatedns.RecordSetsClientCreateOrUpdateResponse, error) { + var ttl endpoint.TTL + if parameters.Properties.TTL != nil { + ttl = endpoint.TTL(*parameters.Properties.TTL) } + client.updatedEndpoints = append( + client.updatedEndpoints, + endpoint.NewEndpointWithTTL( + formatAzureDNSName(relativeRecordSetName, privateZoneName), + string(recordType), + ttl, + extractAzurePrivateDNSTargets(¶meters)..., + ), + ) + return privatedns.RecordSetsClientCreateOrUpdateResponse{}, nil + //return parameters, nil } -func (client *mockPrivateZonesClient) ListByResourceGroupComplete(ctx context.Context, resourceGroupName string, top *int32) (result privatedns.PrivateZoneListResultIterator, err error) { - // pre-iterate to first item to emulate behaviour of Azure SDK - err = client.mockZonesClientIterator.NextWithContext(ctx) - if err != nil { - return *client.mockZonesClientIterator, err +func createMockPrivateZone(zone string, id string) *privatedns.PrivateZone { + return &privatedns.PrivateZone{ + ID: to.Ptr(id), + Name: to.Ptr(zone), } - - return *client.mockZonesClientIterator, nil } func privateARecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { - aRecords := make([]privatedns.ARecord, len(values)) + aRecords := make([]*privatedns.ARecord, len(values)) for i, value := range values { - aRecords[i] = privatedns.ARecord{ - Ipv4Address: to.StringPtr(value), + aRecords[i] = &privatedns.ARecord{ + IPv4Address: to.Ptr(value), } } return &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - ARecords: &aRecords, + TTL: to.Ptr(ttl), + ARecords: aRecords, } } func privateCNameRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { return &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), + TTL: to.Ptr(ttl), CnameRecord: &privatedns.CnameRecord{ - Cname: to.StringPtr(values[0]), + Cname: to.Ptr(values[0]), }, } } func privateTxtRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { return &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - TxtRecords: &[]privatedns.TxtRecord{ + TTL: to.Ptr(ttl), + TxtRecords: []*privatedns.TxtRecord{ { - Value: &[]string{values[0]}, + Value: []*string{&values[0]}, }, }, } @@ -136,19 +162,19 @@ func privateTxtRecordSetPropertiesGetter(values []string, ttl int64) *privatedns func privateOthersRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { return &privatedns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), + TTL: to.Ptr(ttl), } } -func createPrivateMockRecordSet(name, recordType string, values ...string) privatedns.RecordSet { +func createPrivateMockRecordSet(name, recordType string, values ...string) *privatedns.RecordSet { return createPrivateMockRecordSetMultiWithTTL(name, recordType, 0, values...) } -func createPrivateMockRecordSetWithTTL(name, recordType, value string, ttl int64) privatedns.RecordSet { +func createPrivateMockRecordSetWithTTL(name, recordType, value string, ttl int64) *privatedns.RecordSet { return createPrivateMockRecordSetMultiWithTTL(name, recordType, ttl, value) } -func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) privatedns.RecordSet { +func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) *privatedns.RecordSet { var getterFunc func(values []string, ttl int64) *privatedns.RecordSetProperties switch recordType { @@ -161,84 +187,17 @@ func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64, default: getterFunc = privateOthersRecordSetPropertiesGetter } - return privatedns.RecordSet{ - Name: to.StringPtr(name), - Type: to.StringPtr("Microsoft.Network/privateDnsZones/" + recordType), - RecordSetProperties: getterFunc(values, ttl), + return &privatedns.RecordSet{ + Name: to.Ptr(name), + Type: to.Ptr("Microsoft.Network/privateDnsZones/" + recordType), + Properties: getterFunc(values, ttl), } } -func (client *mockPrivateRecordSetsClient) ListComplete(ctx context.Context, resourceGroupName string, zoneName string, top *int32, recordSetNameSuffix string) (result privatedns.RecordSetListResultIterator, err error) { - // pre-iterate to first item to emulate behaviour of Azure SDK - err = client.mockRecordSetListIterator.NextWithContext(ctx) - if err != nil { - return *client.mockRecordSetListIterator, err - } - - return *client.mockRecordSetListIterator, nil -} - -func (client *mockPrivateRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, ifMatch string) (result autorest.Response, err error) { - client.deletedEndpoints = append( - client.deletedEndpoints, - endpoint.NewEndpoint( - formatAzureDNSName(relativeRecordSetName, privateZoneName), - string(recordType), - "", - ), - ) - return autorest.Response{}, nil -} - -func (client *mockPrivateRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, ifMatch string, ifNoneMatch string) (result privatedns.RecordSet, err error) { - var ttl endpoint.TTL - if parameters.TTL != nil { - ttl = endpoint.TTL(*parameters.TTL) - } - client.updatedEndpoints = append( - client.updatedEndpoints, - endpoint.NewEndpointWithTTL( - formatAzureDNSName(relativeRecordSetName, privateZoneName), - string(recordType), - ttl, - extractAzurePrivateDNSTargets(¶meters)..., - ), - ) - return parameters, nil -} - // newMockedAzurePrivateDNSProvider creates an AzureProvider comprising the mocked clients for zones and recordsets -func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones *[]privatedns.PrivateZone, recordSets *[]privatedns.RecordSet) (*AzurePrivateDNSProvider, error) { - // init zone-related parts of the mock-client - pageIterator := mockPrivateZoneListResultPageIterator{ - results: []privatedns.PrivateZoneListResult{ - { - Value: zones, - }, - }, - } - - mockZoneListResultPage := privatedns.NewPrivateZoneListResultPage(privatedns.PrivateZoneListResult{}, pageIterator.getNextPage) - mockZoneClientIterator := privatedns.NewPrivateZoneListResultIterator(mockZoneListResultPage) - zonesClient := mockPrivateZonesClient{ - mockZonesClientIterator: &mockZoneClientIterator, - } - - // init record-related parts of the mock-client - resultPageIterator := mockPrivateRecordSetListResultPageIterator{ - results: []privatedns.RecordSetListResult{ - { - Value: recordSets, - }, - }, - } - - mockRecordSetListResultPage := privatedns.NewRecordSetListResultPage(privatedns.RecordSetListResult{}, resultPageIterator.getNextPage) - mockRecordSetListIterator := privatedns.NewRecordSetListResultIterator(mockRecordSetListResultPage) - recordSetsClient := mockPrivateRecordSetsClient{ - mockRecordSetListIterator: &mockRecordSetListIterator, - } - +func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet) (*AzurePrivateDNSProvider, error) { + zonesClient := newMockPrivateZonesClient(zones) + recordSetsClient := newMockPrivateRecordSectsClient(recordSets) return newAzurePrivateDNSProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil } @@ -255,10 +214,10 @@ func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter func TestAzurePrivateDNSRecord(t *testing.T) { provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", - &[]privatedns.PrivateZone{ + []*privatedns.PrivateZone{ createMockPrivateZone("example.com", "/privateDnsZones/example.com"), }, - &[]privatedns.RecordSet{ + []*privatedns.RecordSet{ createPrivateMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), createPrivateMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), createPrivateMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), @@ -288,10 +247,10 @@ func TestAzurePrivateDNSRecord(t *testing.T) { func TestAzurePrivateDNSMultiRecord(t *testing.T) { provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", - &[]privatedns.PrivateZone{ + []*privatedns.PrivateZone{ createMockPrivateZone("example.com", "/privateDnsZones/example.com"), }, - &[]privatedns.RecordSet{ + []*privatedns.RecordSet{ createPrivateMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), createPrivateMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), createPrivateMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"), @@ -356,30 +315,11 @@ func TestAzurePrivateDNSApplyChangesDryRun(t *testing.T) { } func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client PrivateRecordSetsClient) { - zlr := privatedns.PrivateZoneListResult{ - Value: &[]privatedns.PrivateZone{ - createMockPrivateZone("example.com", "/privateDnsZones/example.com"), - createMockPrivateZone("other.com", "/privateDnsZones/other.com"), - }, - } - - results := []privatedns.PrivateZoneListResult{ - zlr, - } - - mockZoneListResultPage := privatedns.NewPrivateZoneListResultPage(privatedns.PrivateZoneListResult{}, func(ctxParam context.Context, zlrParam privatedns.PrivateZoneListResult) (privatedns.PrivateZoneListResult, error) { - if len(results) > 0 { - result := results[0] - results = nil - return result, nil - } - return privatedns.PrivateZoneListResult{}, nil - }) - mockZoneClientIterator := privatedns.NewPrivateZoneListResultIterator(mockZoneListResultPage) - - zonesClient := mockPrivateZonesClient{ - mockZonesClientIterator: &mockZoneClientIterator, + zones := []*privatedns.PrivateZone{ + createMockPrivateZone("example.com", "/privateDnsZones/example.com"), + createMockPrivateZone("other.com", "/privateDnsZones/other.com"), } + zonesClient := newMockPrivateZonesClient(zones) provider := newAzurePrivateDNSProvider( endpoint.NewDomainFilter([]string{""}), diff --git a/provider/azure/azure_test.go b/provider/azure/azure_test.go index 0598dd4f3c..c5ef811996 100644 --- a/provider/azure/azure_test.go +++ b/provider/azure/azure_test.go @@ -20,9 +20,9 @@ import ( "context" "testing" - "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2018-05-01/dns" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/to" + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" "github.com/stretchr/testify/assert" "sigs.k8s.io/external-dns/endpoint" @@ -34,100 +34,125 @@ import ( // mockZonesClient implements the methods of the Azure DNS Zones Client which are used in the Azure Provider // and returns static results which are defined per test type mockZonesClient struct { - mockZonesClientIterator *dns.ZoneListResultIterator + pagingHandler azcoreruntime.PagingHandler[dns.ZonesClientListByResourceGroupResponse] +} + +func newMockZonesClient(zones []*dns.Zone) mockZonesClient { + pagingHandler := azcoreruntime.PagingHandler[dns.ZonesClientListByResourceGroupResponse]{ + More: func(resp dns.ZonesClientListByResourceGroupResponse) bool { + return false + }, + Fetcher: func(context.Context, *dns.ZonesClientListByResourceGroupResponse) (dns.ZonesClientListByResourceGroupResponse, error) { + return dns.ZonesClientListByResourceGroupResponse{ + ZoneListResult: dns.ZoneListResult{ + Value: zones, + }, + }, nil + }, + } + return mockZonesClient{ + pagingHandler: pagingHandler, + } +} + +func (client *mockZonesClient) NewListByResourceGroupPager(resourceGroupName string, options *dns.ZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[dns.ZonesClientListByResourceGroupResponse] { + return azcoreruntime.NewPager(client.pagingHandler) } // mockZonesClient implements the methods of the Azure DNS RecordSet Client which are used in the Azure Provider // and returns static results which are defined per test type mockRecordSetsClient struct { - mockRecordSetListIterator *dns.RecordSetListResultIterator - deletedEndpoints []*endpoint.Endpoint - updatedEndpoints []*endpoint.Endpoint -} - -// mockZoneListResultPageIterator is used to paginate forward through a list of zones -type mockZoneListResultPageIterator struct { - offset int - results []dns.ZoneListResult + pagingHandler azcoreruntime.PagingHandler[dns.RecordSetsClientListAllByDNSZoneResponse] + deletedEndpoints []*endpoint.Endpoint + updatedEndpoints []*endpoint.Endpoint } -// getNextPage provides the next page based on the offset of the mockZoneListResultPageIterator -func (m *mockZoneListResultPageIterator) getNextPage(context.Context, dns.ZoneListResult) (dns.ZoneListResult, error) { - // it assumed that instances of this kind of iterator are only skimmed through once per test - // otherwise a real implementation is required, e.g. based on a linked list - if m.offset < len(m.results) { - m.offset++ - return m.results[m.offset-1], nil +func newMockRecordSetsClient(recordSets []*dns.RecordSet) mockRecordSetsClient { + pagingHandler := azcoreruntime.PagingHandler[dns.RecordSetsClientListAllByDNSZoneResponse]{ + More: func(resp dns.RecordSetsClientListAllByDNSZoneResponse) bool { + return false + }, + Fetcher: func(context.Context, *dns.RecordSetsClientListAllByDNSZoneResponse) (dns.RecordSetsClientListAllByDNSZoneResponse, error) { + return dns.RecordSetsClientListAllByDNSZoneResponse{ + RecordSetListResult: dns.RecordSetListResult{ + Value: recordSets, + }, + }, nil + }, + } + return mockRecordSetsClient{ + pagingHandler: pagingHandler, } - - // paged to last page or empty - return dns.ZoneListResult{}, nil } -// mockZoneListResultPageIterator is used to paginate forward through a list of recordsets -type mockRecordSetListResultPageIterator struct { - offset int - results []dns.RecordSetListResult +func (client *mockRecordSetsClient) NewListAllByDNSZonePager(resourceGroupName string, zoneName string, options *dns.RecordSetsClientListAllByDNSZoneOptions) *azcoreruntime.Pager[dns.RecordSetsClientListAllByDNSZoneResponse] { + return azcoreruntime.NewPager(client.pagingHandler) } -// getNextPage provides the next page based on the offset of the mockRecordSetListResultPageIterator -func (m *mockRecordSetListResultPageIterator) getNextPage(context.Context, dns.RecordSetListResult) (dns.RecordSetListResult, error) { - // it assumed that instances of this kind of iterator are only skimmed through once per test - // otherwise a real implementation is required, e.g. based on a linked list - if m.offset < len(m.results) { - m.offset++ - return m.results[m.offset-1], nil - } - - // paged to last page or empty - return dns.RecordSetListResult{}, nil +func (client *mockRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, options *dns.RecordSetsClientDeleteOptions) (dns.RecordSetsClientDeleteResponse, error) { + client.deletedEndpoints = append( + client.deletedEndpoints, + endpoint.NewEndpoint( + formatAzureDNSName(relativeRecordSetName, zoneName), + string(recordType), + "", + ), + ) + return dns.RecordSetsClientDeleteResponse{}, nil } -func createMockZone(zone string, id string) dns.Zone { - return dns.Zone{ - ID: to.StringPtr(id), - Name: to.StringPtr(zone), +func (client *mockRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, options *dns.RecordSetsClientCreateOrUpdateOptions) (dns.RecordSetsClientCreateOrUpdateResponse, error) { + var ttl endpoint.TTL + if parameters.Properties.TTL != nil { + ttl = endpoint.TTL(*parameters.Properties.TTL) } + client.updatedEndpoints = append( + client.updatedEndpoints, + endpoint.NewEndpointWithTTL( + formatAzureDNSName(relativeRecordSetName, zoneName), + string(recordType), + ttl, + extractAzureTargets(¶meters)..., + ), + ) + return dns.RecordSetsClientCreateOrUpdateResponse{}, nil } -func (client *mockZonesClient) ListByResourceGroupComplete(ctx context.Context, resourceGroupName string, top *int32) (result dns.ZoneListResultIterator, err error) { - // pre-iterate to first item to emulate behaviour of Azure SDK - err = client.mockZonesClientIterator.NextWithContext(ctx) - if err != nil { - return *client.mockZonesClientIterator, err +func createMockZone(zone string, id string) *dns.Zone { + return &dns.Zone{ + ID: to.Ptr(id), + Name: to.Ptr(zone), } - - return *client.mockZonesClientIterator, nil } func aRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { - aRecords := make([]dns.ARecord, len(values)) + aRecords := make([]*dns.ARecord, len(values)) for i, value := range values { - aRecords[i] = dns.ARecord{ - Ipv4Address: to.StringPtr(value), + aRecords[i] = &dns.ARecord{ + IPv4Address: to.Ptr(value), } } return &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - ARecords: &aRecords, + TTL: to.Ptr(ttl), + ARecords: aRecords, } } func cNameRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { return &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), + TTL: to.Ptr(ttl), CnameRecord: &dns.CnameRecord{ - Cname: to.StringPtr(values[0]), + Cname: to.Ptr(values[0]), }, } } func txtRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { return &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), - TxtRecords: &[]dns.TxtRecord{ + TTL: to.Ptr(ttl), + TxtRecords: []*dns.TxtRecord{ { - Value: &[]string{values[0]}, + Value: []*string{to.Ptr(values[0])}, }, }, } @@ -135,19 +160,19 @@ func txtRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProp func othersRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { return &dns.RecordSetProperties{ - TTL: to.Int64Ptr(ttl), + TTL: to.Ptr(ttl), } } -func createMockRecordSet(name, recordType string, values ...string) dns.RecordSet { +func createMockRecordSet(name, recordType string, values ...string) *dns.RecordSet { return createMockRecordSetMultiWithTTL(name, recordType, 0, values...) } -func createMockRecordSetWithTTL(name, recordType, value string, ttl int64) dns.RecordSet { +func createMockRecordSetWithTTL(name, recordType, value string, ttl int64) *dns.RecordSet { return createMockRecordSetMultiWithTTL(name, recordType, ttl, value) } -func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) dns.RecordSet { +func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) *dns.RecordSet { var getterFunc func(values []string, ttl int64) *dns.RecordSetProperties switch recordType { @@ -160,84 +185,17 @@ func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values default: getterFunc = othersRecordSetPropertiesGetter } - return dns.RecordSet{ - Name: to.StringPtr(name), - Type: to.StringPtr("Microsoft.Network/dnszones/" + recordType), - RecordSetProperties: getterFunc(values, ttl), - } -} - -func (client *mockRecordSetsClient) ListAllByDNSZoneComplete(ctx context.Context, resourceGroupName string, zoneName string, top *int32, recordSetNameSuffix string) (result dns.RecordSetListResultIterator, err error) { - // pre-iterate to first item to emulate behaviour of Azure SDK - err = client.mockRecordSetListIterator.NextWithContext(ctx) - if err != nil { - return *client.mockRecordSetListIterator, err + return &dns.RecordSet{ + Name: to.Ptr(name), + Type: to.Ptr("Microsoft.Network/dnszones/" + recordType), + Properties: getterFunc(values, ttl), } - - return *client.mockRecordSetListIterator, nil -} - -func (client *mockRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, ifMatch string) (result autorest.Response, err error) { - client.deletedEndpoints = append( - client.deletedEndpoints, - endpoint.NewEndpoint( - formatAzureDNSName(relativeRecordSetName, zoneName), - string(recordType), - "", - ), - ) - return autorest.Response{}, nil -} - -func (client *mockRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, ifMatch string, ifNoneMatch string) (result dns.RecordSet, err error) { - var ttl endpoint.TTL - if parameters.TTL != nil { - ttl = endpoint.TTL(*parameters.TTL) - } - client.updatedEndpoints = append( - client.updatedEndpoints, - endpoint.NewEndpointWithTTL( - formatAzureDNSName(relativeRecordSetName, zoneName), - string(recordType), - ttl, - extractAzureTargets(¶meters)..., - ), - ) - return parameters, nil } // newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets -func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones *[]dns.Zone, recordSets *[]dns.RecordSet) (*AzureProvider, error) { - // init zone-related parts of the mock-client - pageIterator := mockZoneListResultPageIterator{ - results: []dns.ZoneListResult{ - { - Value: zones, - }, - }, - } - - mockZoneListResultPage := dns.NewZoneListResultPage(dns.ZoneListResult{}, pageIterator.getNextPage) - mockZoneClientIterator := dns.NewZoneListResultIterator(mockZoneListResultPage) - zonesClient := mockZonesClient{ - mockZonesClientIterator: &mockZoneClientIterator, - } - - // init record-related parts of the mock-client - resultPageIterator := mockRecordSetListResultPageIterator{ - results: []dns.RecordSetListResult{ - { - Value: recordSets, - }, - }, - } - - mockRecordSetListResultPage := dns.NewRecordSetListResultPage(dns.RecordSetListResult{}, resultPageIterator.getNextPage) - mockRecordSetListIterator := dns.NewRecordSetListResultIterator(mockRecordSetListResultPage) - recordSetsClient := mockRecordSetsClient{ - mockRecordSetListIterator: &mockRecordSetListIterator, - } - +func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones []*dns.Zone, recordSets []*dns.RecordSet) (*AzureProvider, error) { + zonesClient := newMockZonesClient(zones) + recordSetsClient := newMockRecordSetsClient(recordSets) return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil } @@ -260,10 +218,10 @@ func validateAzureEndpoints(t *testing.T, endpoints []*endpoint.Endpoint, expect func TestAzureRecord(t *testing.T) { provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", - &[]dns.Zone{ + []*dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, - &[]dns.RecordSet{ + []*dns.RecordSet{ createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), @@ -294,10 +252,10 @@ func TestAzureRecord(t *testing.T) { func TestAzureMultiRecord(t *testing.T) { provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", - &[]dns.Zone{ + []*dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, - &[]dns.RecordSet{ + []*dns.RecordSet{ createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"), @@ -363,30 +321,11 @@ func TestAzureApplyChangesDryRun(t *testing.T) { } func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsClient) { - zlr := dns.ZoneListResult{ - Value: &[]dns.Zone{ - createMockZone("example.com", "/dnszones/example.com"), - createMockZone("other.com", "/dnszones/other.com"), - }, - } - - results := []dns.ZoneListResult{ - zlr, - } - - mockZoneListResultPage := dns.NewZoneListResultPage(dns.ZoneListResult{}, func(ctxParam context.Context, zlrParam dns.ZoneListResult) (dns.ZoneListResult, error) { - if len(results) > 0 { - result := results[0] - results = nil - return result, nil - } - return dns.ZoneListResult{}, nil - }) - mockZoneClientIterator := dns.NewZoneListResultIterator(mockZoneListResultPage) - - zonesClient := mockZonesClient{ - mockZonesClientIterator: &mockZoneClientIterator, + zones := []*dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + createMockZone("other.com", "/dnszones/other.com"), } + zonesClient := newMockZonesClient(zones) provider := newAzureProvider( endpoint.NewDomainFilter([]string{""}), @@ -443,11 +382,11 @@ func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsC func TestAzureNameFilter(t *testing.T) { provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"nginx.example.com"}), endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", - &[]dns.Zone{ + []*dns.Zone{ createMockZone("example.com", "/dnszones/example.com"), }, - &[]dns.RecordSet{ + []*dns.RecordSet{ createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), @@ -496,29 +435,7 @@ func TestAzureApplyChangesZoneName(t *testing.T) { } func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client RecordSetsClient) { - zlr := dns.ZoneListResult{ - Value: &[]dns.Zone{ - createMockZone("example.com", "/dnszones/example.com"), - }, - } - - results := []dns.ZoneListResult{ - zlr, - } - - mockZoneListResultPage := dns.NewZoneListResultPage(dns.ZoneListResult{}, func(ctxParam context.Context, zlrParam dns.ZoneListResult) (dns.ZoneListResult, error) { - if len(results) > 0 { - result := results[0] - results = nil - return result, nil - } - return dns.ZoneListResult{}, nil - }) - mockZoneClientIterator := dns.NewZoneListResultIterator(mockZoneListResultPage) - - zonesClient := mockZonesClient{ - mockZonesClientIterator: &mockZoneClientIterator, - } + zonesClient := newMockZonesClient([]*dns.Zone{createMockZone("example.com", "/dnszones/example.com")}) provider := newAzureProvider( endpoint.NewDomainFilter([]string{"foo.example.com"}), diff --git a/provider/azure/config.go b/provider/azure/config.go index 67baed947d..8cbfb15492 100644 --- a/provider/azure/config.go +++ b/provider/azure/config.go @@ -21,24 +21,24 @@ import ( "io/ioutil" "strings" - "github.com/Azure/go-autorest/autorest/adal" - "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) // config represents common config items for Azure DNS and Azure Private DNS type config struct { - Cloud string `json:"cloud" yaml:"cloud"` - Environment azure.Environment `json:"-" yaml:"-"` - TenantID string `json:"tenantId" yaml:"tenantId"` - SubscriptionID string `json:"subscriptionId" yaml:"subscriptionId"` - ResourceGroup string `json:"resourceGroup" yaml:"resourceGroup"` - Location string `json:"location" yaml:"location"` - ClientID string `json:"aadClientId" yaml:"aadClientId"` - ClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"` - UseManagedIdentityExtension bool `json:"useManagedIdentityExtension" yaml:"useManagedIdentityExtension"` - UserAssignedIdentityID string `json:"userAssignedIdentityID" yaml:"userAssignedIdentityID"` + Cloud string `json:"cloud" yaml:"cloud"` + TenantID string `json:"tenantId" yaml:"tenantId"` + SubscriptionID string `json:"subscriptionId" yaml:"subscriptionId"` + ResourceGroup string `json:"resourceGroup" yaml:"resourceGroup"` + Location string `json:"location" yaml:"location"` + ClientID string `json:"aadClientId" yaml:"aadClientId"` + ClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"` + UseManagedIdentityExtension bool `json:"useManagedIdentityExtension" yaml:"useManagedIdentityExtension"` + UserAssignedIdentityID string `json:"userAssignedIdentityID" yaml:"userAssignedIdentityID"` } func getConfig(configFile, resourceGroup, userAssignedIdentityClientID string) (*config, error) { @@ -60,23 +60,16 @@ func getConfig(configFile, resourceGroup, userAssignedIdentityClientID string) ( if userAssignedIdentityClientID != "" { cfg.UserAssignedIdentityID = userAssignedIdentityClientID } - - var environment azure.Environment - if cfg.Cloud == "" { - environment = azure.PublicCloud - } else { - environment, err = azure.EnvironmentFromName(cfg.Cloud) - if err != nil { - return nil, fmt.Errorf("invalid cloud value '%s': %v", cfg.Cloud, err) - } - } - cfg.Environment = environment - return cfg, nil } // getAccessToken retrieves Azure API access token. -func getAccessToken(cfg config, environment azure.Environment) (*adal.ServicePrincipalToken, error) { +func getCredentials(cfg config) (azcore.TokenCredential, error) { + cloudCfg, err := getCloudConfiguration(cfg.Cloud) + if err != nil { + return nil, err + } + // Try to retrieve token with service principal credentials. // Try to use service principal first, some AKS clusters are in an intermediate state that `UseManagedIdentityExtension` is `true` // and service principal exists. In this case, we still want to use service principal to authenticate. @@ -88,40 +81,48 @@ func getAccessToken(cfg config, environment azure.Environment) (*adal.ServicePri !strings.EqualFold(cfg.ClientID, "msi") && !strings.EqualFold(cfg.ClientSecret, "msi") { log.Info("Using client_id+client_secret to retrieve access token for Azure API.") - oauthConfig, err := adal.NewOAuthConfig(environment.ActiveDirectoryEndpoint, cfg.TenantID) - if err != nil { - return nil, fmt.Errorf("failed to retrieve OAuth config: %v", err) + opts := &azidentity.ClientSecretCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Cloud: cloudCfg, + }, } - - token, err := adal.NewServicePrincipalToken(*oauthConfig, cfg.ClientID, cfg.ClientSecret, environment.ResourceManagerEndpoint) + cred, err := azidentity.NewClientSecretCredential(cfg.TenantID, cfg.ClientID, cfg.ClientSecret, opts) if err != nil { - return nil, fmt.Errorf("failed to create service principal token: %v", err) + return nil, fmt.Errorf("failed to create service principal token: %w", err) } - return token, nil + return cred, nil } // Try to retrieve token with MSI. if cfg.UseManagedIdentityExtension { log.Info("Using managed identity extension to retrieve access token for Azure API.") - + msiOpt := azidentity.ManagedIdentityCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Cloud: cloudCfg, + }, + } if cfg.UserAssignedIdentityID != "" { - log.Infof("Resolving to user assigned identity, client id is %s.", cfg.UserAssignedIdentityID) - token, err := adal.NewServicePrincipalTokenFromManagedIdentity(environment.ServiceManagementEndpoint, &adal.ManagedIdentityOptions{ - ClientID: cfg.UserAssignedIdentityID, - }) - if err != nil { - return nil, fmt.Errorf("failed to create the managed service identity token: %v", err) - } - return token, nil + msiOpt.ID = azidentity.ClientID(cfg.UserAssignedIdentityID) } - - log.Info("Resolving to system assigned identity.") - token, err := adal.NewServicePrincipalTokenFromManagedIdentity(environment.ServiceManagementEndpoint, nil) + cred, err := azidentity.NewManagedIdentityCredential(&msiOpt) if err != nil { - return nil, fmt.Errorf("failed to create the managed service identity token: %v", err) + return nil, fmt.Errorf("failed to create the managed service identity token: %w", err) } - return token, nil + return cred, nil } return nil, fmt.Errorf("no credentials provided for Azure API") } + +func getCloudConfiguration(name string) (cloud.Configuration, error) { + name = strings.ToUpper(name) + switch name { + case "AZURECLOUD", "AZUREPUBLICCLOUD", "": + return cloud.AzurePublic, nil + case "AZUREUSGOVERNMENT", "AZUREUSGOVERNMENTCLOUD": + return cloud.AzureGovernment, nil + case "AZURECHINACLOUD": + return cloud.AzureChina, nil + } + return cloud.Configuration{}, fmt.Errorf("unknown cloud name: %s", name) +} diff --git a/provider/azure/config_test.go b/provider/azure/config_test.go index a2129113a3..7551fa5169 100644 --- a/provider/azure/config_test.go +++ b/provider/azure/config_test.go @@ -17,51 +17,29 @@ limitations under the License. package azure import ( - "fmt" - "io/ioutil" - "os" - "reflect" "testing" - "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" ) -func TestGetAzureEnvironmentConfig(t *testing.T) { - tmp, err := ioutil.TempFile("", "azureconf") - if err != nil { - t.Errorf("couldn't write temp file %v", err) - } - defer os.Remove(tmp.Name()) - +func TestGetCloudConfiguration(t *testing.T) { tests := map[string]struct { - cloud string - err error + cloudName string + expected cloud.Configuration }{ - "AzureChinaCloud": {"AzureChinaCloud", nil}, - "AzureGermanCloud": {"AzureGermanCloud", nil}, - "AzurePublicCloud": {"", nil}, - "AzureUSGovernment": {"AzureUSGovernmentCloud", nil}, + "AzureChinaCloud": {"AzureChinaCloud", cloud.AzureChina}, + "AzurePublicCloud": {"", cloud.AzurePublic}, + "AzureUSGovernment": {"AzureUSGovernmentCloud", cloud.AzureGovernment}, } for name, test := range tests { t.Run(name, func(t *testing.T) { - _, _ = tmp.Seek(0, 0) - _, _ = tmp.Write([]byte(fmt.Sprintf(`{"cloud": "%s"}`, test.cloud))) - got, err := getConfig(tmp.Name(), "", "") + cloudCfg, err := getCloudConfiguration(test.cloudName) if err != nil { t.Errorf("got unexpected err %v", err) } - - if test.cloud == "" { - test.cloud = "AzurePublicCloud" - } - want, err := azure.EnvironmentFromName(test.cloud) - if err != nil { - t.Errorf("couldn't get azure environment from provided name %v", err) - } - - if !reflect.DeepEqual(want, got.Environment) { - t.Errorf("got %v, want %v", got.Environment, want) + if cloudCfg.ActiveDirectoryAuthorityHost != test.expected.ActiveDirectoryAuthorityHost { + t.Errorf("got %v, want %v", cloudCfg, test.expected) } }) }