Skip to content

Commit

Permalink
feat: make request info available in env for preheated kernels (#1109)
Browse files Browse the repository at this point in the history
* feat: make request info available in env for preheated kernels

The mechanism that was used for 'get_query_string()' is changed so
that all request info is sent. For backward compatibility
'get_query_string()' is reimplemented by using the changed code.

* refactor: no need to make a copy of environ in this case

Co-authored-by: Maarten Breddels <[email protected]>

* docs: add documentation for wait_for_request

* refactor: remove backward compatibility for get_query_string

Co-authored-by: Maarten Breddels <[email protected]>
  • Loading branch information
mariobuikhuizen and maartenbreddels authored Mar 18, 2022
1 parent 35fb38e commit 9d4063e
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 40 deletions.
10 changes: 6 additions & 4 deletions docs/source/customize.rst
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,16 @@ In normal mode, Voilà users can get the `query string` at run time through the
import os
query_string = os.getenv('QUERY_STRING')
In preheating kernel mode, users can just replace the ``os.getenv`` call with the helper ``get_query_string`` from ``voila.utils``
In preheating kernel mode, users can prepend with ``wait_for_request`` from ``voila.utils``

.. code-block:: python
from voila.utils import get_query_string
query_string = get_query_string()
import os
from voila.utils import wait_for_request
wait_for_request()
query_string = os.getenv('QUERY_STRING')
``get_query_string`` will pause the execution of the notebook in the preheated kernel at this cell and wait for an actual user to connect to Voilà, then ``get_query_string`` will return the URL `query string` and continue the execution of the remaining cells.
``wait_for_request`` will pause the execution of the notebook in the preheated kernel at this cell and wait for an actual user to connect to Voilà, set the request info environment variables and then continue the execution of the remaining cells.

If the Voilà websocket handler is not started with the default protocol (`ws`), the default IP address (`127.0.0.1`) or the default port (`8866`), users need to provide these values through the environment variables ``VOILA_APP_PROTOCOL``, ``VOILA_APP_IP`` and ``VOILA_APP_PORT``. The easiest way is to set these variables in the `voila.json` configuration file, for example:

Expand Down
4 changes: 2 additions & 2 deletions voila/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from .exporter import VoilaExporter
from .shutdown_kernel_handler import VoilaShutdownKernelHandler
from .voila_kernel_manager import voila_kernel_manager_factory
from .query_parameters_handler import QueryStringSocketHandler
from .request_info_handler import RequestInfoSocketHandler
from .utils import create_include_assets_functions

_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
Expand Down Expand Up @@ -500,7 +500,7 @@ def start(self):
handlers.append(
(
url_path_join(self.server_url, r'/voila/query/%s' % _kernel_id_regex),
QueryStringSocketHandler
RequestInfoSocketHandler
)
)
# Serving notebook extensions
Expand Down
23 changes: 12 additions & 11 deletions voila/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ._version import __version__
from .notebook_renderer import NotebookRenderer
from .query_parameters_handler import QueryStringSocketHandler
from .request_info_handler import RequestInfoSocketHandler
from .utils import ENV_VARIABLE, create_include_assets_functions


Expand Down Expand Up @@ -80,25 +80,25 @@ async def get_generator(self, path=None):
cwd = os.path.dirname(notebook_path)

# Adding request uri to kernel env
kernel_env = os.environ.copy()
kernel_env[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
kernel_env[
request_info = dict()
request_info[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
request_info[
ENV_VARIABLE.PATH_INFO
] = '' # would be /foo/bar if voila.ipynb/foo/bar was supported
kernel_env[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
kernel_env[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
kernel_env[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
request_info[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
request_info[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
request_info[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
host, port = split_host_and_port(self.request.host.lower())
kernel_env[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
kernel_env[ENV_VARIABLE.SERVER_NAME] = host
request_info[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
request_info[ENV_VARIABLE.SERVER_NAME] = host
# Add HTTP Headers as env vars following rfc3875#section-4.1.18
if len(self.voila_configuration.http_header_envs) > 0:
for header_name in self.request.headers:
config_headers_lower = [header.lower() for header in self.voila_configuration.http_header_envs]
# Use case insensitive comparison of header names as per rfc2616#section-4.2
if header_name.lower() in config_headers_lower:
env_name = f'HTTP_{header_name.upper().replace("-", "_")}'
kernel_env[env_name] = self.request.headers.get(header_name)
request_info[env_name] = self.request.headers.get(header_name)

template_arg = self.get_argument("voila-template", None)
theme_arg = self.get_argument("voila-theme", None)
Expand Down Expand Up @@ -132,7 +132,7 @@ async def get_generator(self, path=None):
notebook_name=notebook_path,
)

QueryStringSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': self.request.query})
RequestInfoSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': request_info})
# Send rendered cell to frontend
if len(rendered_cache) > 0:
yield ''.join(rendered_cache)
Expand Down Expand Up @@ -183,6 +183,7 @@ def time_out():

return '<script>voila_heartbeat()</script>\n'

kernel_env = {**os.environ, **request_info}
kernel_env[ENV_VARIABLE.VOILA_PREHEAT] = 'False'
kernel_env[ENV_VARIABLE.VOILA_BASE_URL] = self.base_url
kernel_id = await ensure_async(
Expand Down
22 changes: 11 additions & 11 deletions voila/query_parameters_handler.py → voila/request_info_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from typing import Dict


class QueryStringSocketHandler(WebSocketHandler):
"""A websocket handler used to provide the query string
class RequestInfoSocketHandler(WebSocketHandler):
"""A websocket handler used to provide the request info
assocciated with kernel ids in preheat kernel mode.
Class variables
---------------
- _waiters : A dictionary which holds the `websocket` connection
assocciated with the kernel id.
- cache : A dictionary which holds the query string assocciated
- cache : A dictionary which holds the request info assocciated
with the kernel id.
"""
_waiters = dict()
Expand All @@ -26,28 +26,28 @@ def open(self, kernel_id: str) -> None:
kernel_id (str): Kernel id used by the notebook when it opens
the websocket connection.
"""
QueryStringSocketHandler._waiters[kernel_id] = self
RequestInfoSocketHandler._waiters[kernel_id] = self
if kernel_id in self._cache:
self.write_message(self._cache[kernel_id])

def on_close(self) -> None:
for k_id, waiter in QueryStringSocketHandler._waiters.items():
for k_id, waiter in RequestInfoSocketHandler._waiters.items():
if waiter == self:
break
del QueryStringSocketHandler._waiters[k_id]
del RequestInfoSocketHandler._waiters[k_id]

@classmethod
def send_updates(cls: 'QueryStringSocketHandler', msg: Dict) -> None:
"""Class method used to dispath the query string to the waiting
notebook. This method is called in `VoilaHandler` when the query
string becomes available.
def send_updates(cls: 'RequestInfoSocketHandler', msg: Dict) -> None:
"""Class method used to dispath the request info to the waiting
notebook. This method is called in `VoilaHandler` when the request
info becomes available.
If this method is called before the opening of websocket connection,
`msg` is stored in `_cache0` and the message will be dispatched when
a notebook with coresponding kernel id is connected.
Args:
- msg (Dict): this dictionary contains the `kernel_id` to identify
the waiting notebook and `payload` is the query string.
the waiting notebook and `payload` is the request info.
"""
kernel_id = msg['kernel_id']
payload = msg['payload']
Expand Down
26 changes: 14 additions & 12 deletions voila/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import threading
from enum import Enum
from typing import Awaitable
import json

import websockets

Expand Down Expand Up @@ -58,29 +59,29 @@ def get_server_root_dir(settings):
return root_dir


async def _get_query_string(ws_url: str) -> Awaitable:
async def _get_request_info(ws_url: str) -> Awaitable:
async with websockets.connect(ws_url) as websocket:
qs = await websocket.recv()
return qs
ri = await websocket.recv()
return ri


def get_query_string(url: str = None) -> str:
def wait_for_request(url: str = None) -> str:
"""Helper function to pause the execution of notebook and wait for
the query string.
the pre-heated kernel to be used and all request info is added to
the environment.
Args:
url (str, optional): Address to get user query string, if it is not
url (str, optional): Address to get request info, if it is not
provided, `voila` will figure out from the environment variables.
Defaults to None.
Returns: The query string provided by `QueryStringSocketHandler`.
"""

preheat_mode = os.getenv(ENV_VARIABLE.VOILA_PREHEAT, 'False')
if preheat_mode == 'False':
return os.getenv(ENV_VARIABLE.QUERY_STRING)
return

query_string = None
request_info = None
if url is None:
protocol = os.getenv(ENV_VARIABLE.VOILA_APP_PROTOCOL, 'ws')
server_ip = os.getenv(ENV_VARIABLE.VOILA_APP_IP, '127.0.0.1')
Expand All @@ -92,9 +93,9 @@ def get_query_string(url: str = None) -> str:
ws_url = f'{url}/{kernel_id}'

def inner():
nonlocal query_string
nonlocal request_info
loop = asyncio.new_event_loop()
query_string = loop.run_until_complete(_get_query_string(ws_url))
request_info = loop.run_until_complete(_get_request_info(ws_url))

thread = threading.Thread(target=inner)
try:
Expand All @@ -103,7 +104,8 @@ def inner():
except (KeyboardInterrupt, SystemExit):
asyncio.get_event_loop().stop()

return query_string
for k, v in json.loads(request_info).items():
os.environ[k] = v


def make_url(template_name, base_url, path):
Expand Down

0 comments on commit 9d4063e

Please sign in to comment.