diff --git a/tools/azure-sdk-tools/devtools_testutils/__init__.py b/tools/azure-sdk-tools/devtools_testutils/__init__.py index 02d2e7791178d..44a865a73d921 100644 --- a/tools/azure-sdk-tools/devtools_testutils/__init__.py +++ b/tools/azure-sdk-tools/devtools_testutils/__init__.py @@ -19,7 +19,7 @@ from .envvariable_loader import EnvironmentVariableLoader PowerShellPreparer = EnvironmentVariableLoader # Backward compat from .proxy_startup import start_test_proxy, stop_test_proxy, test_proxy -from .proxy_testcase import recorded_by_proxy +from .proxy_testcase import recorded_by_proxy, recorded_test from .sanitizers import ( add_body_key_sanitizer, add_body_regex_sanitizer, @@ -66,6 +66,7 @@ "PowerShellPreparer", "EnvironmentVariableLoader", "recorded_by_proxy", + "recorded_test", "test_proxy", "set_bodiless_matcher", "set_custom_default_matcher", diff --git a/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py b/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py index 88b2a1f311f45..25c4af56580d6 100644 --- a/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py +++ b/tools/azure-sdk-tools/devtools_testutils/proxy_testcase.py @@ -197,3 +197,44 @@ def combined_call(*args, **kwargs): return test_output return record_wrap + + +@pytest.fixture +def recorded_test(request): + if sys.version_info.major == 2 and not is_live(): + pytest.skip("Playback testing is incompatible with the azure-sdk-tools test proxy on Python 2") + + def transform_args(*args, **kwargs): + copied_positional_args = list(args) + request = copied_positional_args[1] + + transform_request(request, recording_id) + + return tuple(copied_positional_args), kwargs + + if is_live_and_not_recording(): + return + + test_id = get_test_id() + recording_id, variables = start_record_or_playback(test_id) + original_transport_func = RequestsTransport.send + + def combined_call(*args, **kwargs): + adjusted_args, adjusted_kwargs = transform_args(*args, **kwargs) + result = original_transport_func(*adjusted_args, **adjusted_kwargs) + + # make the x-recording-upstream-base-uri the URL of the request + # this makes the request look like it was made to the original endpoint instead of to the proxy + # without this, things like LROPollers can get broken by polling the wrong endpoint + parsed_result = url_parse.urlparse(result.request.url) + upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"]) + upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc} + original_target = parsed_result._replace(**upstream_uri_dict).geturl() + + result.request.url = original_target + return result + + RequestsTransport.send = combined_call + yield # test gets run here + RequestsTransport.send = original_transport_func # test finished running -- tear down + stop_record_or_playback(test_id, recording_id, None) # TODO: how do we provide variables to record?