From 9d4063ef0b8c5b9f81e649c85e323104677e1673 Mon Sep 17 00:00:00 2001 From: Mario Buikhuizen Date: Fri, 18 Mar 2022 10:21:51 +0100 Subject: [PATCH] feat: make request info available in env for preheated kernels (#1109) * feat: make request info available in env for preheated kernels The mechanism that was used for 'get_query_string()' is changed so that all request info is sent. For backward compatibility 'get_query_string()' is reimplemented by using the changed code. * refactor: no need to make a copy of environ in this case Co-authored-by: Maarten Breddels * docs: add documentation for wait_for_request * refactor: remove backward compatibility for get_query_string Co-authored-by: Maarten Breddels --- docs/source/customize.rst | 10 ++++--- voila/app.py | 4 +-- voila/handler.py | 23 ++++++++-------- ...ers_handler.py => request_info_handler.py} | 22 ++++++++-------- voila/utils.py | 26 ++++++++++--------- 5 files changed, 45 insertions(+), 40 deletions(-) rename voila/{query_parameters_handler.py => request_info_handler.py} (73%) diff --git a/docs/source/customize.rst b/docs/source/customize.rst index 319e76368..05b828d8d 100755 --- a/docs/source/customize.rst +++ b/docs/source/customize.rst @@ -399,14 +399,16 @@ In normal mode, Voilà users can get the `query string` at run time through the import os query_string = os.getenv('QUERY_STRING') -In preheating kernel mode, users can just replace the ``os.getenv`` call with the helper ``get_query_string`` from ``voila.utils`` +In preheating kernel mode, users can prepend with ``wait_for_request`` from ``voila.utils`` .. code-block:: python - from voila.utils import get_query_string - query_string = get_query_string() + import os + from voila.utils import wait_for_request + wait_for_request() + query_string = os.getenv('QUERY_STRING') -``get_query_string`` will pause the execution of the notebook in the preheated kernel at this cell and wait for an actual user to connect to Voilà, then ``get_query_string`` will return the URL `query string` and continue the execution of the remaining cells. +``wait_for_request`` will pause the execution of the notebook in the preheated kernel at this cell and wait for an actual user to connect to Voilà, set the request info environment variables and then continue the execution of the remaining cells. If the Voilà websocket handler is not started with the default protocol (`ws`), the default IP address (`127.0.0.1`) or the default port (`8866`), users need to provide these values through the environment variables ``VOILA_APP_PROTOCOL``, ``VOILA_APP_IP`` and ``VOILA_APP_PORT``. The easiest way is to set these variables in the `voila.json` configuration file, for example: diff --git a/voila/app.py b/voila/app.py index fd92f78cd..841134f70 100644 --- a/voila/app.py +++ b/voila/app.py @@ -62,7 +62,7 @@ from .exporter import VoilaExporter from .shutdown_kernel_handler import VoilaShutdownKernelHandler from .voila_kernel_manager import voila_kernel_manager_factory -from .query_parameters_handler import QueryStringSocketHandler +from .request_info_handler import RequestInfoSocketHandler from .utils import create_include_assets_functions _kernel_id_regex = r"(?P\w+-\w+-\w+-\w+-\w+)" @@ -500,7 +500,7 @@ def start(self): handlers.append( ( url_path_join(self.server_url, r'/voila/query/%s' % _kernel_id_regex), - QueryStringSocketHandler + RequestInfoSocketHandler ) ) # Serving notebook extensions diff --git a/voila/handler.py b/voila/handler.py index 93de9f1dc..3b2c6d614 100644 --- a/voila/handler.py +++ b/voila/handler.py @@ -22,7 +22,7 @@ from ._version import __version__ from .notebook_renderer import NotebookRenderer -from .query_parameters_handler import QueryStringSocketHandler +from .request_info_handler import RequestInfoSocketHandler from .utils import ENV_VARIABLE, create_include_assets_functions @@ -80,17 +80,17 @@ async def get_generator(self, path=None): cwd = os.path.dirname(notebook_path) # Adding request uri to kernel env - kernel_env = os.environ.copy() - kernel_env[ENV_VARIABLE.SCRIPT_NAME] = self.request.path - kernel_env[ + request_info = dict() + request_info[ENV_VARIABLE.SCRIPT_NAME] = self.request.path + request_info[ ENV_VARIABLE.PATH_INFO ] = '' # would be /foo/bar if voila.ipynb/foo/bar was supported - kernel_env[ENV_VARIABLE.QUERY_STRING] = str(self.request.query) - kernel_env[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__) - kernel_env[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version) + request_info[ENV_VARIABLE.QUERY_STRING] = str(self.request.query) + request_info[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__) + request_info[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version) host, port = split_host_and_port(self.request.host.lower()) - kernel_env[ENV_VARIABLE.SERVER_PORT] = str(port) if port else '' - kernel_env[ENV_VARIABLE.SERVER_NAME] = host + request_info[ENV_VARIABLE.SERVER_PORT] = str(port) if port else '' + request_info[ENV_VARIABLE.SERVER_NAME] = host # Add HTTP Headers as env vars following rfc3875#section-4.1.18 if len(self.voila_configuration.http_header_envs) > 0: for header_name in self.request.headers: @@ -98,7 +98,7 @@ async def get_generator(self, path=None): # Use case insensitive comparison of header names as per rfc2616#section-4.2 if header_name.lower() in config_headers_lower: env_name = f'HTTP_{header_name.upper().replace("-", "_")}' - kernel_env[env_name] = self.request.headers.get(header_name) + request_info[env_name] = self.request.headers.get(header_name) template_arg = self.get_argument("voila-template", None) theme_arg = self.get_argument("voila-theme", None) @@ -132,7 +132,7 @@ async def get_generator(self, path=None): notebook_name=notebook_path, ) - QueryStringSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': self.request.query}) + RequestInfoSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': request_info}) # Send rendered cell to frontend if len(rendered_cache) > 0: yield ''.join(rendered_cache) @@ -183,6 +183,7 @@ def time_out(): return '\n' + kernel_env = {**os.environ, **request_info} kernel_env[ENV_VARIABLE.VOILA_PREHEAT] = 'False' kernel_env[ENV_VARIABLE.VOILA_BASE_URL] = self.base_url kernel_id = await ensure_async( diff --git a/voila/query_parameters_handler.py b/voila/request_info_handler.py similarity index 73% rename from voila/query_parameters_handler.py rename to voila/request_info_handler.py index 5888e6080..422736816 100644 --- a/voila/query_parameters_handler.py +++ b/voila/request_info_handler.py @@ -3,8 +3,8 @@ from typing import Dict -class QueryStringSocketHandler(WebSocketHandler): - """A websocket handler used to provide the query string +class RequestInfoSocketHandler(WebSocketHandler): + """A websocket handler used to provide the request info assocciated with kernel ids in preheat kernel mode. Class variables @@ -12,7 +12,7 @@ class QueryStringSocketHandler(WebSocketHandler): - _waiters : A dictionary which holds the `websocket` connection assocciated with the kernel id. - - cache : A dictionary which holds the query string assocciated + - cache : A dictionary which holds the request info assocciated with the kernel id. """ _waiters = dict() @@ -26,28 +26,28 @@ def open(self, kernel_id: str) -> None: kernel_id (str): Kernel id used by the notebook when it opens the websocket connection. """ - QueryStringSocketHandler._waiters[kernel_id] = self + RequestInfoSocketHandler._waiters[kernel_id] = self if kernel_id in self._cache: self.write_message(self._cache[kernel_id]) def on_close(self) -> None: - for k_id, waiter in QueryStringSocketHandler._waiters.items(): + for k_id, waiter in RequestInfoSocketHandler._waiters.items(): if waiter == self: break - del QueryStringSocketHandler._waiters[k_id] + del RequestInfoSocketHandler._waiters[k_id] @classmethod - def send_updates(cls: 'QueryStringSocketHandler', msg: Dict) -> None: - """Class method used to dispath the query string to the waiting - notebook. This method is called in `VoilaHandler` when the query - string becomes available. + def send_updates(cls: 'RequestInfoSocketHandler', msg: Dict) -> None: + """Class method used to dispath the request info to the waiting + notebook. This method is called in `VoilaHandler` when the request + info becomes available. If this method is called before the opening of websocket connection, `msg` is stored in `_cache0` and the message will be dispatched when a notebook with coresponding kernel id is connected. Args: - msg (Dict): this dictionary contains the `kernel_id` to identify - the waiting notebook and `payload` is the query string. + the waiting notebook and `payload` is the request info. """ kernel_id = msg['kernel_id'] payload = msg['payload'] diff --git a/voila/utils.py b/voila/utils.py index da5f6ef7f..ebe15ffbe 100644 --- a/voila/utils.py +++ b/voila/utils.py @@ -13,6 +13,7 @@ import threading from enum import Enum from typing import Awaitable +import json import websockets @@ -58,29 +59,29 @@ def get_server_root_dir(settings): return root_dir -async def _get_query_string(ws_url: str) -> Awaitable: +async def _get_request_info(ws_url: str) -> Awaitable: async with websockets.connect(ws_url) as websocket: - qs = await websocket.recv() - return qs + ri = await websocket.recv() + return ri -def get_query_string(url: str = None) -> str: +def wait_for_request(url: str = None) -> str: """Helper function to pause the execution of notebook and wait for - the query string. + the pre-heated kernel to be used and all request info is added to + the environment. Args: - url (str, optional): Address to get user query string, if it is not + url (str, optional): Address to get request info, if it is not provided, `voila` will figure out from the environment variables. Defaults to None. - Returns: The query string provided by `QueryStringSocketHandler`. """ preheat_mode = os.getenv(ENV_VARIABLE.VOILA_PREHEAT, 'False') if preheat_mode == 'False': - return os.getenv(ENV_VARIABLE.QUERY_STRING) + return - query_string = None + request_info = None if url is None: protocol = os.getenv(ENV_VARIABLE.VOILA_APP_PROTOCOL, 'ws') server_ip = os.getenv(ENV_VARIABLE.VOILA_APP_IP, '127.0.0.1') @@ -92,9 +93,9 @@ def get_query_string(url: str = None) -> str: ws_url = f'{url}/{kernel_id}' def inner(): - nonlocal query_string + nonlocal request_info loop = asyncio.new_event_loop() - query_string = loop.run_until_complete(_get_query_string(ws_url)) + request_info = loop.run_until_complete(_get_request_info(ws_url)) thread = threading.Thread(target=inner) try: @@ -103,7 +104,8 @@ def inner(): except (KeyboardInterrupt, SystemExit): asyncio.get_event_loop().stop() - return query_string + for k, v in json.loads(request_info).items(): + os.environ[k] = v def make_url(template_name, base_url, path):