From 80b1f35f8bec00a573f8b47e9704734d1cd2e3a6 Mon Sep 17 00:00:00 2001 From: HandsomeJack Date: Mon, 15 Apr 2024 16:24:16 +0200 Subject: [PATCH] refactor(oauth): make WithTenant extensible with authority types --- .../internal/oauth/ops/authority/authority.go | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/apps/internal/oauth/ops/authority/authority.go b/apps/internal/oauth/ops/authority/authority.go index 6be274f4..a9a70186 100644 --- a/apps/internal/oauth/ops/authority/authority.go +++ b/apps/internal/oauth/ops/authority/authority.go @@ -235,23 +235,24 @@ func NewAuthParams(clientID string, authorityInfo Info) AuthParams { // - the client is configured to authenticate only Microsoft accounts via the "consumers" endpoint // - the resulting authority URL is invalid func (p AuthParams) WithTenant(ID string) (AuthParams, error) { - switch ID { - case "", p.AuthorityInfo.Tenant: - // keep the default tenant because the caller didn't override it + if ID == "" || ID == p.AuthorityInfo.Tenant { return p, nil - case "common", "consumers", "organizations": - if p.AuthorityInfo.AuthorityType == AAD { + } + + var authority string + switch p.AuthorityInfo.AuthorityType { + case AAD: + if ID == "common" || ID == "consumers" || ID == "organizations" { return p, fmt.Errorf(`tenant ID must be a specific tenant, not "%s"`, ID) } - // else we'll return a better error below - } - if p.AuthorityInfo.AuthorityType != AAD { - return p, errors.New("the authority doesn't support tenants") - } - if p.AuthorityInfo.Tenant == "consumers" { - return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`) + if p.AuthorityInfo.Tenant == "consumers" { + return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`) + } + authority = "https://" + path.Join(p.AuthorityInfo.Host, ID) + case ADFS: + return p, errors.New("ADFS authority doesn't support tenants") } - authority := "https://" + path.Join(p.AuthorityInfo.Host, ID) + info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled) if err == nil { info.Region = p.AuthorityInfo.Region