Skip to content

Commit

Permalink
Allow requests to opt out of HTTPS enforcement (#9821)
Browse files Browse the repository at this point in the history
* 'enforce_https' kwarg controls https enforcement

* preserve enforcement opt out in request context

* remove usage of asyncio.coroutine

* correct parameter name for Mock 3.5
  • Loading branch information
chlowell authored Mar 20, 2020
1 parent 7c6f586 commit 81eb35b
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,19 @@ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argu
self._token = None # type: Optional[AccessToken]

@staticmethod
def _enforce_tls(request):
def _enforce_https(request):
# type: (PipelineRequest) -> None
if not request.http_request.url.lower().startswith("https"):

# move 'enforce_https' from options to context so it persists
# across retries but isn't passed to a transport implementation
option = request.context.options.pop("enforce_https", None)

# True is the default setting; we needn't preserve an explicit opt in to the default behavior
if option is False:
request.context["enforce_https"] = option

enforce_https = request.context.get("enforce_https", True)
if enforce_https and not request.http_request.url.lower().startswith("https"):
raise ServiceRequestError(
"Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
)
Expand Down Expand Up @@ -76,7 +86,7 @@ def on_request(self, request):
:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
self._enforce_tls(request)
self._enforce_https(request)

if self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def on_request(self, request: PipelineRequest):
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._enforce_tls(request)
self._enforce_https(request)

with self._lock:
if self._need_new_token:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from unittest.mock import Mock

from azure.core.credentials import AccessToken
from azure.core.exceptions import ServiceRequestError
from azure.core.exceptions import AzureError, ServiceRequestError
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpRequest
import pytest

Expand Down Expand Up @@ -48,7 +48,7 @@ async def verify_request(request):
assert request.http_request is expected_request
return expected_response

fake_credential = Mock(get_token=asyncio.coroutine(lambda _: AccessToken("", 0)))
fake_credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("", 0)))
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)]
response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request)

Expand All @@ -67,7 +67,10 @@ async def get_token(_):
return expected_token

credential = Mock(get_token=get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=asyncio.coroutine(lambda _: Mock()))]
policies = [
AsyncBearerTokenCredentialPolicy(credential, "scope"),
Mock(send=Mock(return_value=get_completed_future())),
]
pipeline = AsyncPipeline(transport=Mock, policies=policies)

await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
Expand All @@ -79,7 +82,7 @@ async def get_token(_):
expired_token = AccessToken("token", time.time())
get_token_calls = 0
expected_token = expired_token
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=asyncio.coroutine(lambda _: Mock()))]
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=lambda _: get_completed_future())]
pipeline = AsyncPipeline(transport=Mock(), policies=policies)

await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
Expand All @@ -90,8 +93,65 @@ async def get_token(_):


@pytest.mark.asyncio
async def test_bearer_policy_enforces_tls():
credential = Mock()
pipeline = AsyncPipeline(transport=Mock(), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")])
async def test_bearer_policy_optionally_enforces_https():
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""

async def assert_option_popped(request, **kwargs):
assert "enforce_https" not in kwargs, "AsyncBearerTokenCredentialPolicy didn't pop the 'enforce_https' option"

credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42)))
pipeline = AsyncPipeline(
transport=Mock(send=assert_option_popped), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")]
)

# by default and when enforce_https=True, the policy should raise when given an insecure request
with pytest.raises(ServiceRequestError):
await pipeline.run(HttpRequest("GET", "http://not.secure"))
with pytest.raises(ServiceRequestError):
await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True)

# when enforce_https=False, an insecure request should pass
await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)

# https requests should always pass
await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False)
await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True)
await pipeline.run(HttpRequest("GET", "https://secure"))


@pytest.mark.asyncio
async def test_preserves_enforce_https_opt_out():
"""The policy should use request context to preserve an opt out from https enforcement"""

class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert "enforce_https" in request.context, "'enforce_https' is not in the request's context"

get_token = get_completed_future(AccessToken("***", 42))
credential = Mock(get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future()), policies=policies)

await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)


@pytest.mark.asyncio
async def test_context_unmodified_by_default():
"""When no options for the policy accompany a request, the policy shouldn't add anything to the request context"""

class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert not any(request.context), "the policy shouldn't add to the request's context"

get_token = get_completed_future(AccessToken("***", 42))
credential = Mock(get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future()), policies=policies)

await pipeline.run(HttpRequest("GET", "https://secure"))


def get_completed_future(result=None):
fut = asyncio.Future()
fut.set_result(result)
return fut
53 changes: 49 additions & 4 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpRequest

import pytest
Expand Down Expand Up @@ -75,11 +75,56 @@ def test_bearer_policy_token_caching():
assert credential.get_token.call_count == 2 # token expired -> policy should call get_token


def test_bearer_policy_enforces_tls():
credential = Mock()
pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")])
def test_bearer_policy_optionally_enforces_https():
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""

def assert_option_popped(request, **kwargs):
assert "enforce_https" not in kwargs, "BearerTokenCredentialPolicy didn't pop the 'enforce_https' option"

credential = Mock(get_token=lambda *_, **__: AccessToken("***", 42))
pipeline = Pipeline(
transport=Mock(send=assert_option_popped), policies=[BearerTokenCredentialPolicy(credential, "scope")]
)

# by default and when enforce_https=True, the policy should raise when given an insecure request
with pytest.raises(ServiceRequestError):
pipeline.run(HttpRequest("GET", "http://not.secure"))
with pytest.raises(ServiceRequestError):
pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True)

# when enforce_https=False, an insecure request should pass
pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)

# https requests should always pass
pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False)
pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True)
pipeline.run(HttpRequest("GET", "https://secure"))


def test_preserves_enforce_https_opt_out():
"""The policy should use request context to preserve an opt out from https enforcement"""

class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert "enforce_https" in request.context, "'enforce_https' is not in the request's context"

policies = [BearerTokenCredentialPolicy(credential=Mock(), scope="scope"), ContextValidator()]
pipeline = Pipeline(transport=Mock(), policies=policies)

pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)


def test_context_unmodified_by_default():
"""When no options for the policy accompany a request, the policy shouldn't add anything to the request context"""

class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert not any(request.context), "the policy shouldn't add to the request's context"

policies = [BearerTokenCredentialPolicy(credential=Mock(), scope="scope"), ContextValidator()]
pipeline = Pipeline(transport=Mock(), policies=policies)

pipeline.run(HttpRequest("GET", "https://secure"))


@pytest.mark.skipif(azure.core.__version__ >= "2", reason="this test applies only to azure-core 1.x")
Expand Down

0 comments on commit 81eb35b

Please sign in to comment.