From caceb97e4d24c78c83880b55b15252664e822c9b Mon Sep 17 00:00:00 2001 From: Matthew Seal Date: Tue, 11 Oct 2022 07:12:45 -0700 Subject: [PATCH] Dataframe renaming comms listener (#49) --- CHANGELOG.md | 1 + src/dx/comms.py | 71 ++++++++++++++++--- src/dx/settings.py | 1 + tests/{test_comm.py => test_resample_comm.py} | 12 ++-- 4 files changed, 68 insertions(+), 17 deletions(-) rename tests/{test_comm.py => test_resample_comm.py} (79%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8eb042e3..d6042025 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ All notable changes will be documented here. ### Changed - `STRINGIFY_INDEX_VALUES` is `False` by default (index `.name` will still be a string, but values will keep their original type) +- Added comms listener for renaming (SQL cell) dataframes. ## `1.2.0` _2022-08-21_ diff --git a/src/dx/comms.py b/src/dx/comms.py index e42cc746..767d86a7 100644 --- a/src/dx/comms.py +++ b/src/dx/comms.py @@ -1,3 +1,6 @@ +from typing import Optional + +import pandas as pd import structlog from IPython import get_ipython from IPython.core.interactiveshell import InteractiveShell @@ -9,43 +12,89 @@ # ref: https://jupyter-notebook.readthedocs.io/en/stable/comms.html#opening-a-comm-from-the-frontend -def target_func(comm, open_msg): +def resampler(comm, open_msg): + """ + Datalink resample request. + """ + @comm.on_msg def _recv(msg): - handle_comm_msg(msg) + # Is separate function to make testing easier. + handle_resample_comm(msg) -def handle_comm_msg(msg): - content = msg.get("content", {}) +def handle_resample_comm(msg): + content = msg.get("content") if not content: return - data = content.get("data", {}) + data = content.get("data") if not data: return - data = msg["content"]["data"] - if "display_id" in data and "filters" in data: # TODO: check for explicit resample value? msg = DEXResampleMessage.parse_obj(data) handle_resample(msg) +def renamer(comm, open_msg): + """Rename a SQL cell dataframe.""" + + @comm.on_msg + def _recv(msg): + handle_renaming_comm(msg) + + +def handle_renaming_comm(msg: dict, ipython_shell: Optional[InteractiveShell] = None): + """Implementation behind renaming a SQL cell dataframe via comms""" + content = msg.get("content") + if not content: + return + + data = content.get("data") + if not data: + return + + if "old_name" in data and "new_name" in data: + + if ipython_shell is None: # noqa + # shell will be passed in from test suite, otherwise go with global shell. + ipython_shell = get_ipython() + + value_to_rename = ipython_shell.user_ns.get(data["old_name"]) + + # Do not rename unless old_name mapped onto exactly a dataframe. + # + # (Handles case when it maps onto None, indicating that the old name + # hasn't been assigned to at all yet (i.e. user gestured to rename + # SQL cell dataframe name before the SQL cell has even been run the + # first time yet)) + # + if isinstance(value_to_rename, pd.DataFrame): + # New name can be empty string, indicating to drop reference to the var. + if data["new_name"]: + ipython_shell.user_ns[data["new_name"]] = value_to_rename + + # But old name will always be present in message. Delete it now. + del ipython_shell.user_ns[data["old_name"]] + + def register_comm(ipython_shell: InteractiveShell) -> None: """ Registers the comm target function with the IPython kernel. """ from dx.settings import get_settings - if not get_settings().ENABLE_DATALINK: - return - if getattr(ipython_shell, "kernel", None) is None: # likely a TerminalInteractiveShell return - ipython_shell.kernel.comm_manager.register_target("datalink", target_func) + if get_settings().ENABLE_DATALINK: + ipython_shell.kernel.comm_manager.register_target("datalink", resampler) + + if get_settings().ENABLE_RENAMER: + ipython_shell.kernel.comm_manager.register_target("rename", renamer) if (ipython := get_ipython()) is not None: diff --git a/src/dx/settings.py b/src/dx/settings.py index 1bb540d8..df908ea4 100644 --- a/src/dx/settings.py +++ b/src/dx/settings.py @@ -68,6 +68,7 @@ class Settings(BaseSettings): # controls dataframe variable tracking, hashing, and storing in sqlite ENABLE_DATALINK: bool = True + ENABLE_RENAMER: bool = True NUM_PAST_SAMPLES_TRACKED: int = 3 DB_LOCATION: str = ":memory:" diff --git a/tests/test_comm.py b/tests/test_resample_comm.py similarity index 79% rename from tests/test_comm.py rename to tests/test_resample_comm.py index fcd374cf..380bbc8b 100644 --- a/tests/test_comm.py +++ b/tests/test_resample_comm.py @@ -1,11 +1,11 @@ -from dx.comms import handle_comm_msg +from dx.comms import handle_resample_comm from dx.types import DEXResampleMessage -def test_handle_comm_msg(mocker): +def test_handle_resample_comm(mocker): """ Test that `handle_resample` is called with the correctly - formatted resample message type if the comm receives + formatted resample message type if the comm (`handle_resample_comm`) receives the right data structure. """ msg = { @@ -24,12 +24,12 @@ def test_handle_comm_msg(mocker): } } mock_handle_resample = mocker.patch("dx.comms.handle_resample") - handle_comm_msg(msg) + handle_resample_comm(msg) resample_msg = DEXResampleMessage.parse_obj(msg["content"]["data"]) mock_handle_resample.assert_called_once_with(resample_msg) -def test_handle_comm_msg_skipped(mocker): +def test_handle_resample_comm_skipped(mocker): """ Test that `handle_resample` is not called with invalid data. """ @@ -41,5 +41,5 @@ def test_handle_comm_msg_skipped(mocker): } } mock_handle_resample = mocker.patch("dx.comms.handle_resample") - handle_comm_msg(msg) + handle_resample_comm(msg) mock_handle_resample.assert_not_called()