Skip to content

Commit

Permalink
SingleTenant support (#2055)
Browse files Browse the repository at this point in the history
* SingleTenant support

* Pylint corrections

* Black correction

* SingleTenant Gov correction

* Ported Gov SingleTenant fixes from DotNet

* Black correction

* Pylint corrections

* Black corrections for app_credentials

* Corrected AppCredentials._should_set_token

* Changed auth constant to match setting name

* black corrections

---------

Co-authored-by: Tracy Boehrer <[email protected]>
  • Loading branch information
tracyboehrer and Tracy Boehrer authored Feb 13, 2024
1 parent 184a2da commit a078027
Show file tree
Hide file tree
Showing 25 changed files with 234 additions and 233 deletions.
8 changes: 8 additions & 0 deletions doc/SkillClaimsValidation.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ ADAPTER = BotFrameworkAdapter(
SETTINGS,
)
```

For SingleTenant type bots, the additional issuers must be added based on the tenant id:
```python
AUTH_CONFIG = AuthenticationConfiguration(
claims_validator=AllowedSkillsClaimsValidator(CONFIG).claims_validator,
tenant_id=the_tenant_id
)
```
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import requests
import pytest

SKIP = os.getenv("SlackChannel") == ''
SKIP = os.getenv("SlackChannel") == ""


class SlackClient(aiounittest.AsyncTestCase):
Expand Down
13 changes: 0 additions & 13 deletions libraries/botbuilder-core/botbuilder/core/bot_framework_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,6 @@ async def continue_conversation(
context.turn_state[BotAdapter.BOT_CALLBACK_HANDLER_KEY] = callback
context.turn_state[BotAdapter.BOT_OAUTH_SCOPE_KEY] = audience

# If we receive a valid app id in the incoming token claims, add the channel service URL to the
# trusted services list so we can send messages back.
# The service URL for skills is trusted because it is applied by the SkillHandler based on the original
# request received by the root bot
app_id_from_claims = JwtTokenValidation.get_app_id_from_claims(
claims_identity.claims
)
if app_id_from_claims:
if SkillValidation.is_skill_claim(
claims_identity.claims
) or await self._credential_provider.is_valid_appid(app_id_from_claims):
AppCredentials.trust_service_url(reference.service_url)

client = await self.create_connector_client(
reference.service_url, claims_identity, audience
)
Expand Down

This file was deleted.

12 changes: 0 additions & 12 deletions libraries/botbuilder-core/tests/test_bot_framework_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,14 +621,8 @@ async def callback(context: TurnContext):
scope = context.turn_state[BotFrameworkAdapter.BOT_OAUTH_SCOPE_KEY]
assert AuthenticationConstants.TO_CHANNEL_FROM_BOT_OAUTH_SCOPE == scope

# Ensure the serviceUrl was added to the trusted hosts
assert AppCredentials.is_trusted_service(channel_service_url)

refs = ConversationReference(service_url=channel_service_url)

# Ensure the serviceUrl is NOT in the trusted hosts
assert not AppCredentials.is_trusted_service(channel_service_url)

await adapter.continue_conversation(
refs, callback, claims_identity=skills_identity
)
Expand Down Expand Up @@ -694,14 +688,8 @@ async def callback(context: TurnContext):
scope = context.turn_state[BotFrameworkAdapter.BOT_OAUTH_SCOPE_KEY]
assert skill_2_app_id == scope

# Ensure the serviceUrl was added to the trusted hosts
assert AppCredentials.is_trusted_service(skill_2_service_url)

refs = ConversationReference(service_url=skill_2_service_url)

# Ensure the serviceUrl is NOT in the trusted hosts
assert not AppCredentials.is_trusted_service(skill_2_service_url)

await adapter.continue_conversation(
refs, callback, claims_identity=skills_identity, audience=skill_2_app_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,6 @@ async def _http_authenticate_request(self, request: Request) -> bool:
)
)

# Add ServiceURL to the cache of trusted sites in order to allow token refreshing.
self._credentials.trust_service_url(
claims_identity.claims.get(
AuthenticationConstants.SERVICE_URL_CLAIM
)
)
self.claims_identity = claims_identity
return True
except Exception as error:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,38 @@ class ConfigurationServiceClientCredentialFactory(
PasswordServiceClientCredentialFactory
):
def __init__(self, configuration: Any, *, logger: Logger = None) -> None:
super().__init__(
app_id=getattr(configuration, "APP_ID", None),
password=getattr(configuration, "APP_PASSWORD", None),
logger=logger,
app_type = (
configuration.APP_TYPE
if hasattr(configuration, "APP_TYPE")
else "MultiTenant"
)
app_id = configuration.APP_ID if hasattr(configuration, "APP_ID") else None
app_password = (
configuration.APP_PASSWORD
if hasattr(configuration, "APP_PASSWORD")
else None
)
app_tenantid = None

if app_type == "UserAssignedMsi":
raise Exception("UserAssignedMsi APP_TYPE is not supported")

if app_type == "SingleTenant":
app_tenantid = (
configuration.APP_TENANTID
if hasattr(configuration, "APP_TENANTID")
else None
)

if not app_id:
raise Exception("Property 'APP_ID' is expected in configuration object")
if not app_password:
raise Exception(
"Property 'APP_PASSWORD' is expected in configuration object"
)
if not app_tenantid:
raise Exception(
"Property 'APP_TENANTID' is expected in configuration object"
)

super().__init__(app_id, app_password, app_tenantid, logger=logger)
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ async def post_activity(
conversation_id: str,
activity: Activity,
) -> InvokeResponse:
if not from_bot_id:
raise TypeError("from_bot_id")
if not to_bot_id:
raise TypeError("to_bot_id")
if not to_url:
raise TypeError("to_url")
if not service_url:
Expand Down Expand Up @@ -100,6 +96,7 @@ async def post_activity(

headers_dict = {
"Content-type": "application/json; charset=utf-8",
"x-ms-conversation-id": conversation_id,
}
if token:
headers_dict.update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def create_user_token_client(

credentials = await self._credentials_factory.create_credentials(
app_id,
audience=self._to_channel_from_bot_oauth_scope,
oauth_scope=self._to_channel_from_bot_oauth_scope,
login_endpoint=self._login_endpoint,
validate_authority=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
):
super(_GovernmentCloudBotFrameworkAuthentication, self).__init__(
GovernmentConstants.TO_CHANNEL_FROM_BOT_OAUTH_SCOPE,
GovernmentConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL,
GovernmentConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL_PREFIX,
CallerIdConstants.us_gov_channel,
GovernmentConstants.CHANNEL_SERVICE,
GovernmentConstants.OAUTH_URL_GOV,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def create_user_token_client(

credentials = await self._credentials_factory.create_credentials(
app_id,
audience=self._to_channel_from_bot_oauth_scope,
oauth_scope=self._to_channel_from_bot_oauth_scope,
login_endpoint=self._to_channel_from_bot_login_url,
validate_authority=self._validate_authority,
)
Expand Down Expand Up @@ -274,6 +274,11 @@ async def _skill_validation_authenticate_channel_token(
ignore_expiration=False,
)

if self._auth_configuration.valid_token_issuers:
validation_params.issuer.append(
self._auth_configuration.valid_token_issuers
)

# TODO: what should the openIdMetadataUrl be here?
token_extractor = JwtTokenExtractor(
validation_params,
Expand Down Expand Up @@ -362,6 +367,11 @@ async def _emulator_validation_authenticate_emulator_token(
ignore_expiration=False,
)

if self._auth_configuration.valid_token_issuers:
to_bot_from_emulator_validation_params.issuer.append(
self._auth_configuration.valid_token_issuers
)

token_extractor = JwtTokenExtractor(
to_bot_from_emulator_validation_params,
metadata_url=self._to_bot_from_emulator_open_id_metadata_url,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from datetime import datetime, timedelta
from urllib.parse import urlparse

import requests
from msrest.authentication import Authentication

from botframework.connector.auth import AuthenticationConstants
from .authentication_constants import AuthenticationConstants


class AppCredentials(Authentication):
Expand All @@ -17,16 +14,8 @@ class AppCredentials(Authentication):
"""

schema = "Bearer"

trustedHostNames = {
# "state.botframework.com": datetime.max,
# "state.botframework.azure.us": datetime.max,
"api.botframework.com": datetime.max,
"token.botframework.com": datetime.max,
"api.botframework.azure.us": datetime.max,
"token.botframework.azure.us": datetime.max,
}
cache = {}
__tenant = None

def __init__(
self,
Expand All @@ -38,50 +27,55 @@ def __init__(
Initializes a new instance of MicrosoftAppCredentials class
:param channel_auth_tenant: Optional. The oauth token tenant.
"""
tenant = (
channel_auth_tenant
if channel_auth_tenant
else AuthenticationConstants.DEFAULT_CHANNEL_AUTH_TENANT
)
self.microsoft_app_id = app_id
self.tenant = channel_auth_tenant
self.oauth_endpoint = (
AuthenticationConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL_PREFIX + tenant
)
self.oauth_scope = (
oauth_scope or AuthenticationConstants.TO_CHANNEL_FROM_BOT_OAUTH_SCOPE
self._get_to_channel_from_bot_loginurl_prefix() + self.tenant
)
self.oauth_scope = oauth_scope or self._get_to_channel_from_bot_oauthscope()

self.microsoft_app_id = app_id
def _get_default_channelauth_tenant(self) -> str:
return AuthenticationConstants.DEFAULT_CHANNEL_AUTH_TENANT

def _get_to_channel_from_bot_loginurl_prefix(self) -> str:
return AuthenticationConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL_PREFIX

def _get_to_channel_from_bot_oauthscope(self) -> str:
return AuthenticationConstants.TO_CHANNEL_FROM_BOT_OAUTH_SCOPE

@property
def tenant(self) -> str:
return self.__tenant

@tenant.setter
def tenant(self, value: str):
self.__tenant = value or self._get_default_channelauth_tenant()

@staticmethod
def trust_service_url(service_url: str, expiration=None):
"""
Obsolete: trust_service_url is not a required part of the security model.
Checks if the service url is for a trusted host or not.
:param service_url: The service url.
:param expiration: The expiration time after which this service url is not trusted anymore.
:returns: True if the host of the service url is trusted; False otherwise.
"""
if expiration is None:
expiration = datetime.now() + timedelta(days=1)
host = urlparse(service_url).hostname
if host is not None:
AppCredentials.trustedHostNames[host] = expiration

@staticmethod
def is_trusted_service(service_url: str) -> bool:
def is_trusted_service(service_url: str) -> bool: # pylint: disable=unused-argument
"""
Obsolete: is_trusted_service is not a required part of the security model.
Checks if the service url is for a trusted host or not.
:param service_url: The service url.
:returns: True if the host of the service url is trusted; False otherwise.
"""
host = urlparse(service_url).hostname
if host is not None:
return AppCredentials._is_trusted_url(host)
return False
return True

@staticmethod
def _is_trusted_url(host: str) -> bool:
expiration = AppCredentials.trustedHostNames.get(host, datetime.min)
return expiration > (datetime.now() - timedelta(minutes=5))
def _is_trusted_url(host: str) -> bool: # pylint: disable=unused-argument
"""
Obsolete: _is_trusted_url is not a required part of the security model.
"""
return True

# pylint: disable=arguments-differ
def signed_session(self, session: requests.Session = None) -> requests.Session:
Expand All @@ -92,7 +86,7 @@ def signed_session(self, session: requests.Session = None) -> requests.Session:
if not session:
session = requests.Session()

if not self._should_authorize(session):
if not self._should_set_token(session):
session.headers.pop("Authorization", None)
else:
auth_token = self.get_access_token()
Expand All @@ -101,13 +95,13 @@ def signed_session(self, session: requests.Session = None) -> requests.Session:

return session

def _should_authorize(
def _should_set_token(
self, session: requests.Session # pylint: disable=unused-argument
) -> bool:
# We don't set the token if the AppId is not set, since it means that we are in an un-authenticated scenario.
return (
self.microsoft_app_id != AuthenticationConstants.ANONYMOUS_SKILL_APP_ID
and self.microsoft_app_id is not None
and self.microsoft_app_id
)

def get_access_token(self, force_refresh: bool = False) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,39 @@

from typing import Awaitable, Callable, Dict, List

from .authentication_constants import AuthenticationConstants


class AuthenticationConfiguration:
def __init__(
self,
required_endorsements: List[str] = None,
claims_validator: Callable[[List[Dict]], Awaitable] = None,
valid_token_issuers: List[str] = None,
tenant_id: str = None,
):
self.required_endorsements = required_endorsements or []
self.claims_validator = claims_validator
self.valid_token_issuers = valid_token_issuers or []

if tenant_id:
self.add_tenant_issuers(self, tenant_id)

@staticmethod
def add_tenant_issuers(authentication_configuration, tenant_id: str):
authentication_configuration.valid_token_issuers.append(
AuthenticationConstants.VALID_TOKEN_ISSUER_URL_TEMPLATE_V1.format(tenant_id)
)
authentication_configuration.valid_token_issuers.append(
AuthenticationConstants.VALID_TOKEN_ISSUER_URL_TEMPLATE_V2.format(tenant_id)
)
authentication_configuration.valid_token_issuers.append(
AuthenticationConstants.VALID_GOVERNMENT_TOKEN_ISSUER_URL_TEMPLATE_V1.format(
tenant_id
)
)
authentication_configuration.valid_token_issuers.append(
AuthenticationConstants.VALID_GOVERNMENT_TOKEN_ISSUER_URL_TEMPLATE_V2.format(
tenant_id
)
)
Loading

0 comments on commit a078027

Please sign in to comment.