diff --git a/docs/source/garak.generators.rest.rst b/docs/source/garak.generators.rest.rst index 6a789ea7..93a58ce8 100644 --- a/docs/source/garak.generators.rest.rst +++ b/docs/source/garak.generators.rest.rst @@ -18,6 +18,7 @@ Uses the following options from ``_config.plugins.generators["rest.RestGenerator * ``request_timeout`` - How many seconds should we wait before timing out? Default 20 * ``ratelimit_codes`` - Which endpoint HTTP response codes should be caught as indicative of rate limiting and retried? ``List[int]``, default ``[429]`` * ``skip_codes`` - Which endpoint HTTP response code should lead to the generation being treated as not possible and skipped for this query. Takes precedence over ``ratelimit_codes``. +* ``verify_ssl`` - (optional) Enforce ssl certificate validation? Default is ``True``, a file path to a CA bundle can be provided. (bool|str) Templates can be either a string or a JSON-serialisable Python object. Instance of ``$INPUT`` here are replaced with the prompt; instances of ``$KEY`` diff --git a/garak/generators/rest.py b/garak/generators/rest.py index 6516a8f1..b65e3da7 100644 --- a/garak/generators/rest.py +++ b/garak/generators/rest.py @@ -36,6 +36,7 @@ class RestGenerator(Generator): "req_template": "$INPUT", "request_timeout": 20, "proxies": None, + "verify_ssl": True, } ENV_VAR = "REST_API_KEY" @@ -61,6 +62,7 @@ class RestGenerator(Generator): "temperature", "top_k", "proxies", + "verify_ssl", ) def __init__(self, uri=None, config_root=_config): @@ -128,6 +130,10 @@ def __init__(self, uri=None, config_root=_config): "`proxies` value provided is not in the required format. See documentation from the `requests` package for details on expected format. https://requests.readthedocs.io/en/latest/user/advanced/#proxies" ) + # suppress warnings about intentional SSL validation suppression + if isinstance(self.verify_ssl, bool) and not self.verify_ssl: + requests.packages.urllib3.disable_warnings() + # validate jsonpath if self.response_json and self.response_json_field: try: @@ -204,6 +210,7 @@ def _call_model( "headers": request_headers, "timeout": self.request_timeout, "proxies": self.proxies, + "verify": self.verify_ssl, } resp = self.http_function(self.uri, **req_kArgs) diff --git a/tests/generators/test_rest.py b/tests/generators/test_rest.py index 55aa9d12..c0486ef1 100644 --- a/tests/generators/test_rest.py +++ b/tests/generators/test_rest.py @@ -168,3 +168,37 @@ def test_rest_invalid_proxy(requests_mock): with pytest.raises(GarakException) as exc_info: _plugins.load_plugin("generators.rest.RestGenerator", config_root=_config) assert "not in the required format" in str(exc_info.value) + + +@pytest.mark.usefixtures("set_rest_config") +@pytest.mark.parametrize("verify_ssl", (True, False, None)) +def test_rest_ssl_suppression(mocker, requests_mock, verify_ssl): + if verify_ssl is not None: + _config.plugins.generators["rest"]["RestGenerator"]["verify_ssl"] = verify_ssl + else: + verify_ssl = RestGenerator.DEFAULT_PARAMS["verify_ssl"] + generator = _plugins.load_plugin( + "generators.rest.RestGenerator", config_root=_config + ) + requests_mock.post( + DEFAULT_URI, + text=json.dumps( + { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": DEFAULT_TEXT_RESPONSE, + }, + } + ] + } + ), + ) + mock_http_function = mocker.patch.object( + generator, "http_function", wraps=generator.http_function + ) + generator._call_model("Who is Enabran Tain's son?") + mock_http_function.assert_called_once() + assert mock_http_function.call_args_list[0].kwargs["verify"] is verify_ssl