diff --git a/src/Accounts/Accounts.Test/AzureRMProfileTests.cs b/src/Accounts/Accounts.Test/AzureRMProfileTests.cs index 66afd5b7de31..370bb438b154 100644 --- a/src/Accounts/Accounts.Test/AzureRMProfileTests.cs +++ b/src/Accounts/Accounts.Test/AzureRMProfileTests.cs @@ -149,6 +149,36 @@ public void SpecifyTenantAndSubscriptionIdSucceed() Assert.Equal("2021-01-01", client.SubscriptionAndTenantClient.ApiVersion); } + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void SpecifyTenantDomainAndSubscriptionIdSucceed() + { + var tenants = new List<string> { DefaultTenant.ToString() }; + var firstList = new List<string> { DefaultSubscription.ToString(), Guid.NewGuid().ToString() }; + var secondList = new List<string> { Guid.NewGuid().ToString() }; + var client = SetupTestEnvironment(tenants, firstList, secondList); + + ((MockTokenAuthenticationFactory)AzureSession.Instance.AuthenticationFactory).TokenProvider = (account, environment, tenant) => + new MockAccessToken + { + UserId = "aaa@contoso.com", + LoginType = LoginType.OrgId, + AccessToken = "bbb", + TenantId = DefaultTenant.ToString() + }; + + var azureRmProfile = client.Login( + Context.Account, + Context.Environment, + MockSubscriptionClientFactory.GetTenantDomainFromId(DefaultTenant.ToString()), + DefaultSubscription.ToString(), + null, + null, + false, + null); + Assert.Equal("2021-01-01", client.SubscriptionAndTenantClient.ApiVersion); + } + [Fact] [Trait(Category.AcceptanceType, Category.CheckIn)] public void SubscriptionIdNotExist() diff --git a/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactory.cs b/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactory.cs index f0315d64a2f4..2c8af2ea2a7f 100644 --- a/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactory.cs +++ b/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactory.cs @@ -52,6 +52,11 @@ public MockSubscriptionClientFactory() { } + public static string GetTenantDomainFromId(string id) + { + return id.Substring(3)+".com"; + } + public static string GetSubscriptionNameFromId(string id) { if(id == "a11a11aa-aaaa-aaaa-aaaa-aaaa1111aaaa" || id == "aaaa11aa-aaaa-aaaa-aaaa-aaaa1111aaaa") diff --git a/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactoryVersion2019.cs b/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactoryVersion2019.cs index 022fbf706024..f613f1e6f806 100644 --- a/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactoryVersion2019.cs +++ b/src/Accounts/Accounts.Test/Mocks/MockSubscriptionClientFactoryVersion2019.cs @@ -48,7 +48,7 @@ public SubscriptionClient GetSubscriptionClientVerLatest() { return ListTenantQueueDequeueVerLatest(); } - var tenants = _tenants.Select((k) => new TenantIdDescription(id: k, tenantId: k)); + var tenants = _tenants.Select((k) => new TenantIdDescription(id: k, tenantId: k, domains: new List<string>{GetTenantDomainFromId(k)})); var mockPage = new MockPage<TenantIdDescription>(tenants.ToList()); AzureOperationResponse<IPage<TenantIdDescription>> r = new AzureOperationResponse<IPage<TenantIdDescription>> diff --git a/src/Accounts/Accounts/ChangeLog.md b/src/Accounts/Accounts/ChangeLog.md index c7165eb0869d..bd4c872083f3 100644 --- a/src/Accounts/Accounts/ChangeLog.md +++ b/src/Accounts/Accounts/ChangeLog.md @@ -19,6 +19,7 @@ --> ## Upcoming Release +* Supported tenant domain as input while using `Connect-AzAccount` with parameter `Tenant`. [#19471] ## Version 2.10.1 * Deduplicated subscriptions belonging to multiple tenants while using `Get-AzSubscription` with parameter `SubscriptionName`. [#19427] diff --git a/src/Accounts/Accounts/Models/RMProfileClient.cs b/src/Accounts/Accounts/Models/RMProfileClient.cs index 022c8302691b..baba101a206c 100644 --- a/src/Accounts/Accounts/Models/RMProfileClient.cs +++ b/src/Accounts/Accounts/Models/RMProfileClient.cs @@ -113,7 +113,7 @@ public bool TryRemoveContext(IAzureContext context) public AzureRmProfile Login( IAzureAccount account, IAzureEnvironment environment, - string tenantId, + string tenantIdOrName, string subscriptionId, string subscriptionName, SecureString password, @@ -138,13 +138,13 @@ public AzureRmProfile Login( bool needDataPlanAuthFirst = !string.IsNullOrEmpty(authScope); if(needDataPlanAuthFirst) { - var token = AcquireAccessToken(account, environment, tenantId, password, promptBehavior, promptAction, authScope); + var token = AcquireAccessToken(account, environment, tenantIdOrName, password, promptBehavior, promptAction, authScope); promptBehavior = ShowDialog.Never; } if (skipValidation) { - if (string.IsNullOrEmpty(subscriptionId) || string.IsNullOrEmpty(tenantId)) + if (string.IsNullOrEmpty(subscriptionId) || string.IsNullOrEmpty(tenantIdOrName)) { throw new PSInvalidOperationException(Resources.SubscriptionOrTenantMissing); } @@ -154,29 +154,31 @@ public AzureRmProfile Login( Id = subscriptionId }; - newSubscription.SetOrAppendProperty(AzureSubscription.Property.Tenants, tenantId); + newSubscription.SetOrAppendProperty(AzureSubscription.Property.Tenants, tenantIdOrName); newSubscription.SetOrAppendProperty(AzureSubscription.Property.Account, account.Id); newTenant = new AzureTenant { - Id = tenantId + Id = tenantIdOrName }; } else { // (tenant and subscription are present) OR // (tenant is present and subscription is not provided) - if (!string.IsNullOrEmpty(tenantId)) + if (!string.IsNullOrEmpty(tenantIdOrName)) { Guid tempGuid = Guid.Empty; - if (!Guid.TryParse(tenantId, out tempGuid)) + if (!Guid.TryParse(tenantIdOrName, out tempGuid)) { var tenants = ListAccountTenants(account, environment, password, promptBehavior, promptAction); - var homeTenants = tenants.FirstOrDefault(t => t.IsHome); - var tenant = homeTenants ?? tenants.FirstOrDefault(); + var matchesName = tenants.Where(t => t.GetPropertyAsArray(AzureTenant.Property.Domains) + .Contains(tenantIdOrName, StringComparer.InvariantCultureIgnoreCase)); + var homeTenants = matchesName.FirstOrDefault(t => t.IsHome); + var tenant = homeTenants ?? matchesName.FirstOrDefault(); if (tenant == null || tenant.Id == null) { - string baseMessage = string.Format(ProfileMessages.TenantDomainNotFound, tenantId); + string baseMessage = string.Format(ProfileMessages.TenantDomainNotFound, tenantIdOrName); var typeMessageMap = new Dictionary<string, string> { { AzureAccount.AccountType.ServicePrincipal, string.Format(ProfileMessages.ServicePrincipalTenantDomainNotFound, account.Id) }, @@ -187,14 +189,14 @@ public AzureRmProfile Login( throw new ArgumentNullException(string.Format("{0} {1}", baseMessage, typeMessage)); } - tenantId = tenant.Id; + tenantIdOrName = tenant.Id; } var token = AcquireAccessToken( account, environment, - tenantId, + tenantIdOrName, password, promptBehavior, promptAction); @@ -317,7 +319,7 @@ public AzureRmProfile Login( if (shouldPopulateContextList && maxContextPopulation != 0) { var defaultContext = _profile.DefaultContext; - var subscriptions = maxContextPopulation > 0 ? ListSubscriptions(tenantId).Take(maxContextPopulation) : ListSubscriptions(tenantId); + var subscriptions = maxContextPopulation > 0 ? ListSubscriptions(tenantIdOrName).Take(maxContextPopulation) : ListSubscriptions(tenantIdOrName); foreach (var subscription in subscriptions) {