Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataframe renaming comms listener #49

Merged
merged 10 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ All notable changes will be documented here.

### Changed
- `STRINGIFY_INDEX_VALUES` is `False` by default (index `.name` will still be a string, but values will keep their original type)
- Added comms listener for renaming (SQL cell) dataframes.

## `1.2.0`
_2022-08-21_
Expand Down
71 changes: 60 additions & 11 deletions src/dx/comms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Optional

import pandas as pd
import structlog
from IPython import get_ipython
from IPython.core.interactiveshell import InteractiveShell
Expand All @@ -9,43 +12,89 @@


# ref: https://jupyter-notebook.readthedocs.io/en/stable/comms.html#opening-a-comm-from-the-frontend
def target_func(comm, open_msg):
def resampler(comm, open_msg):
"""
Datalink resample request.
"""

@comm.on_msg
def _recv(msg):
handle_comm_msg(msg)
# Is separate function to make testing easier.
handle_resample_comm(msg)


def handle_comm_msg(msg):
content = msg.get("content", {})
def handle_resample_comm(msg):
content = msg.get("content")
if not content:
return

data = content.get("data", {})
data = content.get("data")
if not data:
return

data = msg["content"]["data"]

if "display_id" in data and "filters" in data:
# TODO: check for explicit resample value?
msg = DEXResampleMessage.parse_obj(data)
handle_resample(msg)


def renamer(comm, open_msg):
"""Rename a SQL cell dataframe."""

@comm.on_msg
def _recv(msg):
handle_renaming_comm(msg)


def handle_renaming_comm(msg: dict, ipython_shell: Optional[InteractiveShell] = None):
"""Implementation behind renaming a SQL cell dataframe via comms"""
content = msg.get("content")
if not content:
return

data = content.get("data")
if not data:
return

if "old_name" in data and "new_name" in data:

if ipython_shell is None: # noqa
# shell will be passed in from test suite, otherwise go with global shell.
ipython_shell = get_ipython()

value_to_rename = ipython_shell.user_ns.get(data["old_name"])

# Do not rename unless old_name mapped onto exactly a dataframe.
#
# (Handles case when it maps onto None, indicating that the old name
# hasn't been assigned to at all yet (i.e. user gestured to rename
# SQL cell dataframe name before the SQL cell has even been run the
# first time yet))
#
if isinstance(value_to_rename, pd.DataFrame):
# New name can be empty string, indicating to drop reference to the var.
if data["new_name"]:
ipython_shell.user_ns[data["new_name"]] = value_to_rename

# But old name will always be present in message. Delete it now.
del ipython_shell.user_ns[data["old_name"]]


def register_comm(ipython_shell: InteractiveShell) -> None:
"""
Registers the comm target function with the IPython kernel.
"""
from dx.settings import get_settings

if not get_settings().ENABLE_DATALINK:
return

if getattr(ipython_shell, "kernel", None) is None:
# likely a TerminalInteractiveShell
return

ipython_shell.kernel.comm_manager.register_target("datalink", target_func)
if get_settings().ENABLE_DATALINK:
ipython_shell.kernel.comm_manager.register_target("datalink", resampler)

if get_settings().ENABLE_RENAMER:
ipython_shell.kernel.comm_manager.register_target("rename", renamer)


if (ipython := get_ipython()) is not None:
Expand Down
1 change: 1 addition & 0 deletions src/dx/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Settings(BaseSettings):

# controls dataframe variable tracking, hashing, and storing in sqlite
ENABLE_DATALINK: bool = True
ENABLE_RENAMER: bool = True
NUM_PAST_SAMPLES_TRACKED: int = 3
DB_LOCATION: str = ":memory:"

Expand Down
12 changes: 6 additions & 6 deletions tests/test_comm.py → tests/test_resample_comm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from dx.comms import handle_comm_msg
from dx.comms import handle_resample_comm
from dx.types import DEXResampleMessage


def test_handle_comm_msg(mocker):
def test_handle_resample_comm(mocker):
"""
Test that `handle_resample` is called with the correctly
formatted resample message type if the comm receives
formatted resample message type if the comm (`handle_resample_comm`) receives
the right data structure.
"""
msg = {
Expand All @@ -24,12 +24,12 @@ def test_handle_comm_msg(mocker):
}
}
mock_handle_resample = mocker.patch("dx.comms.handle_resample")
handle_comm_msg(msg)
handle_resample_comm(msg)
resample_msg = DEXResampleMessage.parse_obj(msg["content"]["data"])
mock_handle_resample.assert_called_once_with(resample_msg)


def test_handle_comm_msg_skipped(mocker):
def test_handle_resample_comm_skipped(mocker):
"""
Test that `handle_resample` is not called with invalid data.
"""
Expand All @@ -41,5 +41,5 @@ def test_handle_comm_msg_skipped(mocker):
}
}
mock_handle_resample = mocker.patch("dx.comms.handle_resample")
handle_comm_msg(msg)
handle_resample_comm(msg)
mock_handle_resample.assert_not_called()