From e55d49f7936265aa5af0758ace46af7a95312f87 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Fri, 23 Sep 2022 19:20:58 -0400 Subject: [PATCH 1/7] add duckdb --- poetry.lock | 28 +++++++++++++++++++++++++++- pyproject.toml | 2 ++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 42208e17..50deb867 100644 --- a/poetry.lock +++ b/poetry.lock @@ -235,6 +235,30 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "duckdb" +version = "0.5.1" +description = "DuckDB embedded database" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = ">=1.14" + +[[package]] +name = "duckdb-engine" +version = "0.6.4" +description = "SQLAlchemy driver for duckdb" +category = "main" +optional = false +python-versions = ">=3.6.1" + +[package.dependencies] +duckdb = ">=0.4.0" +numpy = "*" +sqlalchemy = ">=1.3.19,<2.0.0" + [[package]] name = "entrypoints" version = "0.4" @@ -1594,7 +1618,7 @@ docs = ["mkdocs", "mkdocs-material", "mkdocs-jupyter", "mkdocstrings", "mkdocstr [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "df1aa0664713b2b52e946d22efac5ba366dc0bf0dea1ebc31314589702c87309" +content-hash = "dbdf7f19c805328ea053818514b4cca3e23c8f498c8a557436e4ec817d82e922" [metadata.files] appnope = [ @@ -1718,6 +1742,8 @@ defusedxml = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] distlib = [] +duckdb = [] +duckdb-engine = [] entrypoints = [ {file = "entrypoints-0.4-py3-none-any.whl", hash = "sha256:f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f"}, {file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"}, diff --git a/pyproject.toml b/pyproject.toml index 45cc8ade..d0acdf67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ mkdocs-material = {version = "^8.3.9", optional = true} mkdocs-jupyter = {version = "^0.21.0", optional = true} mkdocstrings = {version = "^0.19.0", optional = true} mkdocstrings-python = {version = "^0.7.1", optional = true} +duckdb = "^0.5.1" +duckdb-engine = "^0.6.4" [tool.poetry.dev-dependencies] black = "^22.6.0" From ee33780a0d0708e93faf1b5df90087a9fcf7ca8a Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Fri, 23 Sep 2022 19:21:07 -0400 Subject: [PATCH 2/7] first test, fingers crossed --- src/dx/filtering.py | 11 ++++++----- src/dx/formatters/main.py | 19 +++++++++++++++++-- src/dx/utils/tracking.py | 20 ++------------------ tests/conftest.py | 8 ++++++++ tests/test_datatypes.py | 26 ++++++++++++++++---------- 5 files changed, 49 insertions(+), 35 deletions(-) diff --git a/src/dx/filtering.py b/src/dx/filtering.py index d42cd47c..26719020 100644 --- a/src/dx/filtering.py +++ b/src/dx/filtering.py @@ -7,7 +7,7 @@ from dx.sampling import get_df_dimensions from dx.settings import get_settings, settings_context from dx.types import DEXFilterSettings, DEXResampleMessage -from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, sql_engine +from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, db_connection, generate_df_hash logger = structlog.get_logger(__name__) @@ -65,12 +65,13 @@ def update_display_id( row_limit = limit or settings.DISPLAY_MAX_ROWS dxdf = DXDF_CACHE[display_id] - query_string = sql_filter.format(table_name=dxdf.sql_table) + query_string = sql_filter.format(table_name=dxdf.variable_name) logger.debug(f"sql query string: {query_string}") - new_df = 
pd.read_sql(query_string, sql_engine) - with sql_engine.connect() as conn: - orig_df_count = conn.execute(f"SELECT COUNT (*) FROM {dxdf.sql_table}").scalar() + new_df: pd.DataFrame = db_connection.execute(query_string).df() + count_resp = db_connection.execute(f"SELECT COUNT(*) FROM {dxdf.variable_name}").fetchone() + # should return a tuple of (count,) + orig_df_count = count_resp[0] logger.debug(f"filtered to {len(new_df)}/{orig_df_count} row(s)") metadata = store_sample_to_history(new_df, display_id=display_id, filters=filters) diff --git a/src/dx/formatters/main.py b/src/dx/formatters/main.py index 16836830..7dbe6e39 100644 --- a/src/dx/formatters/main.py +++ b/src/dx/formatters/main.py @@ -14,7 +14,7 @@ from dx.types import DXDisplayMode from dx.utils.datatypes import to_dataframe from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns -from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, store_in_sqlite +from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, db_connection logger = structlog.get_logger(__name__) @@ -57,7 +57,22 @@ def datalink_processing( # this needs to happen after sending to the frontend # so the user doesn't wait as long for writing larger datasets if not parent_display_id: - store_in_sqlite(dxdf.sql_table, dxdf.df) + if dxdf.variable_name.startswith("unk_dataframe"): + # it wasn't assigned to a variable but we still + # want it to be available for push-down filtering + logger.debug( + f"registering unassigned dataframe to duckdb and adding to user_ns: {dxdf.variable_name}" + ) + ipython_shell.user_ns[dxdf.variable_name] = dxdf.df + db_connection.register(dxdf.variable_name, dxdf.df) + else: + # duckdb should already be tracking this variable name + count_resp = db_connection.execute( + f"SELECT COUNT(*) FROM {dxdf.variable_name}" + ).fetchone() + logger.debug( + f"duckdb is already tracking {dxdf.variable_name}, and shows {count_resp[0]} row(s)" + ) return payload, metadata diff --git a/src/dx/utils/tracking.py b/src/dx/utils/tracking.py index 52fc1030..8200ce03 100644 --- a/src/dx/utils/tracking.py +++ b/src/dx/utils/tracking.py @@ -2,12 +2,12 @@ import uuid from typing import List, Optional +import duckdb import pandas as pd import structlog from IPython import get_ipython from IPython.core.interactiveshell import InteractiveShell from pandas.util import hash_pandas_object -from sqlalchemy import create_engine from dx.settings import get_settings from dx.utils.datatypes import has_numeric_strings, is_sequence_series @@ -15,7 +15,7 @@ from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns logger = structlog.get_logger(__name__) -sql_engine = create_engine("sqlite://", echo=False) +db_connection = duckdb.connect(database=":memory:") settings = get_settings() @@ -169,19 +169,3 @@ def get_df_variable_name( logger.debug("no variables found matching this dataframe") df_uuid = f"unk_dataframe_{uuid.uuid4()}".replace("-", "") return df_uuid - - -def store_in_sqlite(table_name: str, df: pd.DataFrame): - logger.debug(f"{df.columns=}") - tracking_df = df.copy() - - logger.debug(f"writing to `{table_name}` table in sqlite") - with sql_engine.begin() as conn: - num_written_rows = tracking_df.to_sql( - table_name, - con=conn, - if_exists="replace", - index=True, # this is the default, but just to be explicit - ) - logger.debug(f"wrote {num_written_rows} row(s) to `{table_name}` table") - return num_written_rows diff --git a/tests/conftest.py 
b/tests/conftest.py index 2f7d3c4d..6bf94095 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import duckdb import numpy as np import pandas as pd import pytest @@ -38,6 +39,13 @@ def get_ipython() -> TerminalInteractiveShell: return shell +@pytest.fixture +def sample_db_connection() -> duckdb.DuckDBPyConnection: + conn = duckdb.connect(":memory:") + yield conn + conn.close() + + @pytest.fixture def sample_dataframe() -> pd.DataFrame: df = pd.DataFrame( diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 81ce3d0c..cbb80709 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -2,9 +2,10 @@ Tests to ensure various data types can be sent functions to - build the table schema and payload/metadata body for each display formatter - hash the dataframe for tracking -- write to sqlite for tracking/filtering +- write to the database for tracking/filtering """ +import duckdb import pandas as pd import pytest from pandas.io.json import build_table_schema @@ -19,7 +20,7 @@ random_dataframe, ) from dx.utils.formatting import clean_column_values -from dx.utils.tracking import generate_df_hash, sql_engine, store_in_sqlite +from dx.utils.tracking import generate_df_hash @pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES) @@ -135,37 +136,42 @@ def test_generate_df_hash(dtype: str): @pytest.mark.xfail(reason="only for dev") @pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES) -def test_to_sql(dtype: str): +def test_to_sql(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection): """ DEV: Test which data types pass/fail when passed directly through .to_sql() with the sqlalchemy engine. """ params = {dt: False for dt in SORTED_DX_DATATYPES} params[dtype] = True + df = random_dataframe(**params) + try: - with sql_engine.connect() as conn: - num_rows = df.to_sql("test", conn, if_exists="replace") + sample_db_connection.register(f"{dtype}_test", df) except Exception as e: assert False, f"{dtype} failed with {e}" + count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone() + num_rows = count_resp[0] assert num_rows == df.shape[0] @pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES) -def test_store_in_sqlite(dtype: str): +def test_store_in_db(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection): """ Test that we've correctly handled data types before storing in sqlite. """ params = {dt: False for dt in SORTED_DX_DATATYPES} params[dtype] = True + df = random_dataframe(**params) + for col in df.columns: df[col] = clean_column_values(df[col]) + try: - num_rows = store_in_sqlite(f"{dtype}_test", df) + sample_db_connection.register(f"{dtype}_test", df) except Exception as e: assert False, f"{dtype} failed with {e}" + count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone() + num_rows = count_resp[0] assert num_rows == df.shape[0] - - -# TODO: test that we can convert back to original datatypes after read_sql? 
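The patch above, and the follow-up revision of it below, replace the SQLAlchemy/SQLite storage path (`store_in_sqlite()` writing every dataframe out through `DataFrame.to_sql()`) with DuckDB, which can register an in-memory dataframe as a queryable view without copying any rows. A minimal sketch of that pattern, assuming only `duckdb` and `pandas` are installed; the connection, view, and column names here are illustrative, not taken from the diffs:

```python
import duckdb
import pandas as pd

# In-memory database, mirroring `duckdb.connect(database=":memory:")` above.
conn = duckdb.connect(database=":memory:")

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# register() exposes the existing DataFrame as a named view; unlike
# DataFrame.to_sql() against a sqlite engine, no rows are copied.
conn.register("df_view", df)

# Push-down filtering: run SQL against the view, fetch the result as a DataFrame.
filtered = conn.execute("SELECT * FROM df_view WHERE a > 1").df()

# COUNT(*) comes back as a one-element tuple, hence fetchone()[0].
orig_count = conn.execute("SELECT COUNT(*) FROM df_view").fetchone()[0]

assert orig_count == len(df)
assert len(filtered) == 2
```

This is also why the later patches can query by `dxdf.variable_name` directly: the registered view name is the kernel variable name, so DEX filter queries template straight into `SELECT ... FROM {table_name}`.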
From 9f345c32b8118a78cc0cad3990eea71c85980208 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Fri, 23 Sep 2022 20:28:30 -0400 Subject: [PATCH 3/7] first test, fingers crossed --- src/dx/filtering.py | 13 +++++++------ src/dx/formatters/main.py | 7 ++++--- src/dx/settings.py | 1 + src/dx/utils/tracking.py | 20 ++------------------ tests/conftest.py | 8 ++++++++ tests/test_datatypes.py | 26 ++++++++++++++++---------- 6 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/dx/filtering.py b/src/dx/filtering.py index d42cd47c..de91121d 100644 --- a/src/dx/filtering.py +++ b/src/dx/filtering.py @@ -7,10 +7,10 @@ from dx.sampling import get_df_dimensions from dx.settings import get_settings, settings_context from dx.types import DEXFilterSettings, DEXResampleMessage -from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, sql_engine +from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, get_db_connection logger = structlog.get_logger(__name__) - +db_connection = get_db_connection() settings = get_settings() @@ -65,12 +65,13 @@ def update_display_id( row_limit = limit or settings.DISPLAY_MAX_ROWS dxdf = DXDF_CACHE[display_id] - query_string = sql_filter.format(table_name=dxdf.sql_table) + query_string = sql_filter.format(table_name=dxdf.variable_name) logger.debug(f"sql query string: {query_string}") - new_df = pd.read_sql(query_string, sql_engine) - with sql_engine.connect() as conn: - orig_df_count = conn.execute(f"SELECT COUNT (*) FROM {dxdf.sql_table}").scalar() + new_df: pd.DataFrame = db_connection.execute(query_string).df() + count_resp = db_connection.execute(f"SELECT COUNT(*) FROM {dxdf.variable_name}").fetchone() + # should return a tuple of (count,) + orig_df_count = count_resp[0] logger.debug(f"filtered to {len(new_df)}/{orig_df_count} row(s)") metadata = store_sample_to_history(new_df, display_id=display_id, filters=filters) diff --git a/src/dx/formatters/main.py b/src/dx/formatters/main.py index 16836830..20078265 100644 --- a/src/dx/formatters/main.py +++ b/src/dx/formatters/main.py @@ -14,10 +14,10 @@ from dx.types import DXDisplayMode from dx.utils.datatypes import to_dataframe from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns -from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, store_in_sqlite +from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, get_db_connection logger = structlog.get_logger(__name__) - +db_connection = get_db_connection() DEFAULT_IPYTHON_DISPLAY_FORMATTER = DisplayFormatter() IN_NOTEBOOK_ENV = False @@ -57,7 +57,8 @@ def datalink_processing( # this needs to happen after sending to the frontend # so the user doesn't wait as long for writing larger datasets if not parent_display_id: - store_in_sqlite(dxdf.sql_table, dxdf.df) + logger.debug(f"registering `{dxdf.variable_name}` to duckdb") + db_connection.register(dxdf.variable_name, dxdf.df.reset_index()) return payload, metadata diff --git a/src/dx/settings.py b/src/dx/settings.py index 97f7289d..1bb540d8 100644 --- a/src/dx/settings.py +++ b/src/dx/settings.py @@ -69,6 +69,7 @@ class Settings(BaseSettings): # controls dataframe variable tracking, hashing, and storing in sqlite ENABLE_DATALINK: bool = True NUM_PAST_SAMPLES_TRACKED: int = 3 + DB_LOCATION: str = ":memory:" @validator("RENDERABLE_OBJECTS", pre=True, always=True) def validate_renderables(cls, vals): diff --git a/src/dx/utils/tracking.py b/src/dx/utils/tracking.py index 52fc1030..8200ce03 
100644 --- a/src/dx/utils/tracking.py +++ b/src/dx/utils/tracking.py @@ -2,12 +2,12 @@ import uuid from typing import List, Optional +import duckdb import pandas as pd import structlog from IPython import get_ipython from IPython.core.interactiveshell import InteractiveShell from pandas.util import hash_pandas_object -from sqlalchemy import create_engine from dx.settings import get_settings from dx.utils.datatypes import has_numeric_strings, is_sequence_series @@ -15,7 +15,7 @@ from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns logger = structlog.get_logger(__name__) -sql_engine = create_engine("sqlite://", echo=False) +db_connection = duckdb.connect(database=":memory:") settings = get_settings() @@ -169,19 +169,3 @@ def get_df_variable_name( logger.debug("no variables found matching this dataframe") df_uuid = f"unk_dataframe_{uuid.uuid4()}".replace("-", "") return df_uuid - - -def store_in_sqlite(table_name: str, df: pd.DataFrame): - logger.debug(f"{df.columns=}") - tracking_df = df.copy() - - logger.debug(f"writing to `{table_name}` table in sqlite") - with sql_engine.begin() as conn: - num_written_rows = tracking_df.to_sql( - table_name, - con=conn, - if_exists="replace", - index=True, # this is the default, but just to be explicit - ) - logger.debug(f"wrote {num_written_rows} row(s) to `{table_name}` table") - return num_written_rows diff --git a/tests/conftest.py b/tests/conftest.py index 2f7d3c4d..6bf94095 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import duckdb import numpy as np import pandas as pd import pytest @@ -38,6 +39,13 @@ def get_ipython() -> TerminalInteractiveShell: return shell +@pytest.fixture +def sample_db_connection() -> duckdb.DuckDBPyConnection: + conn = duckdb.connect(":memory:") + yield conn + conn.close() + + @pytest.fixture def sample_dataframe() -> pd.DataFrame: df = pd.DataFrame( diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 81ce3d0c..cbb80709 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -2,9 +2,10 @@ Tests to ensure various data types can be sent functions to - build the table schema and payload/metadata body for each display formatter - hash the dataframe for tracking -- write to sqlite for tracking/filtering +- write to the database for tracking/filtering """ +import duckdb import pandas as pd import pytest from pandas.io.json import build_table_schema @@ -19,7 +20,7 @@ random_dataframe, ) from dx.utils.formatting import clean_column_values -from dx.utils.tracking import generate_df_hash, sql_engine, store_in_sqlite +from dx.utils.tracking import generate_df_hash @pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES) @@ -135,37 +136,42 @@ def test_generate_df_hash(dtype: str): @pytest.mark.xfail(reason="only for dev") @pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES) -def test_to_sql(dtype: str): +def test_to_sql(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection): """ DEV: Test which data types pass/fail when passed directly through .to_sql() with the sqlalchemy engine. 
""" params = {dt: False for dt in SORTED_DX_DATATYPES} params[dtype] = True + df = random_dataframe(**params) + try: - with sql_engine.connect() as conn: - num_rows = df.to_sql("test", conn, if_exists="replace") + sample_db_connection.register(f"{dtype}_test", df) except Exception as e: assert False, f"{dtype} failed with {e}" + count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone() + num_rows = count_resp[0] assert num_rows == df.shape[0] @pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES) -def test_store_in_sqlite(dtype: str): +def test_store_in_db(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection): """ Test that we've correctly handled data types before storing in sqlite. """ params = {dt: False for dt in SORTED_DX_DATATYPES} params[dtype] = True + df = random_dataframe(**params) + for col in df.columns: df[col] = clean_column_values(df[col]) + try: - num_rows = store_in_sqlite(f"{dtype}_test", df) + sample_db_connection.register(f"{dtype}_test", df) except Exception as e: assert False, f"{dtype} failed with {e}" + count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone() + num_rows = count_resp[0] assert num_rows == df.shape[0] - - -# TODO: test that we can convert back to original datatypes after read_sql? From f6a481ba897648684c3017e592e0e11eba5db542 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Mon, 26 Sep 2022 13:01:58 -0400 Subject: [PATCH 4/7] add date/metric groupby filter fixtures --- tests/conftest.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6bf94095..849ea6ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -114,7 +114,18 @@ def sample_long_wide_dataframe() -> pd.DataFrame: @pytest.fixture def sample_dex_date_filter(sample_random_dataframe: pd.DataFrame) -> dict: return { - "column": "date_column", + "column": "datetime_column", + "type": "DATE_FILTER", + "predicate": "between", + "start": sample_random_dataframe.datetime_column.min(), + "end": sample_random_dataframe.datetime_column.max(), + } + + +@pytest.fixture +def sample_dex_groupby_date_filter(sample_random_dataframe: pd.DataFrame) -> dict: + return { + "column": "datetime_column, min", "type": "DATE_FILTER", "predicate": "between", "start": sample_random_dataframe.datetime_column.min(), @@ -135,6 +146,19 @@ def sample_dex_metric_filter(sample_random_dataframe: pd.DataFrame) -> dict: } +@pytest.fixture +def sample_dex_groupby_metric_filter(sample_random_dataframe: pd.DataFrame) -> dict: + return { + "column": "float_column, min", + "type": "METRIC_FILTER", + "predicate": "between", + "value": [ + sample_random_dataframe.float_column.min(), + sample_random_dataframe.float_column.max(), + ], + } + + @pytest.fixture def sample_dex_dimension_filter(sample_random_dataframe: pd.DataFrame) -> dict: return { @@ -161,8 +185,16 @@ def sample_dex_filters( @pytest.fixture -def sample_display_id() -> str: - return "test-display-123" +def sample_dex_groupby_filters( + sample_dex_groupby_date_filter: dict, + sample_dex_groupby_metric_filter: dict, +) -> list: + return DEXFilterSettings( + filters=[ + sample_dex_groupby_date_filter, + sample_dex_groupby_metric_filter, + ] + ).filters @pytest.fixture From 32d8a72b47417073221e98870f61c94831471176 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Mon, 26 Sep 2022 13:06:58 -0400 Subject: [PATCH 5/7] handle storing .index better to accomodate groupby/multiindex --- src/dx/filtering.py | 17 
+++++++++++++---- src/dx/utils/tracking.py | 18 +++++++++--------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/dx/filtering.py b/src/dx/filtering.py index de91121d..9029e24c 100644 --- a/src/dx/filtering.py +++ b/src/dx/filtering.py @@ -3,6 +3,7 @@ import pandas as pd import structlog from IPython.display import update_display +from IPython.terminal.interactiveshell import InteractiveShell from dx.sampling import get_df_dimensions from dx.settings import get_settings, settings_context @@ -52,6 +53,7 @@ def update_display_id( output_variable_name: Optional[str] = None, limit: Optional[int] = None, cell_id: Optional[str] = None, + ipython_shell: Optional[InteractiveShell] = None, ) -> None: """ Filters the dataframe in the cell with the given display_id. @@ -76,11 +78,15 @@ def update_display_id( metadata = store_sample_to_history(new_df, display_id=display_id, filters=filters) - # resetting original index - new_df.set_index(dxdf.index_name, inplace=True) + # resetting original index if needed + if dxdf.index_name is not None: + new_df.set_index(dxdf.index_name, inplace=True) # convert back to original dtypes for col, dtype in dxdf.original_column_dtypes.items(): + if settings.FLATTEN_COLUMN_VALUES and isinstance(col, tuple): + # the dataframe in use originally had pd.MultiIndex columns + col = ", ".join(col) new_df[col] = new_df[col].astype(dtype) # this is associating the subset with the original dataframe, @@ -105,7 +111,10 @@ def update_display_id( ) -def handle_resample(msg: DEXResampleMessage) -> None: +def handle_resample( + msg: DEXResampleMessage, + ipython_shell: Optional[InteractiveShell] = None, +) -> None: raw_filters = msg.filters sample_size = msg.limit @@ -131,4 +140,4 @@ def handle_resample(msg: DEXResampleMessage) -> None: } ) - update_display_id(**update_params) + update_display_id(ipython_shell=ipython_shell, **update_params) diff --git a/src/dx/utils/tracking.py b/src/dx/utils/tracking.py index b714dc10..13b5ebe0 100644 --- a/src/dx/utils/tracking.py +++ b/src/dx/utils/tracking.py @@ -1,7 +1,7 @@ import hashlib import uuid from functools import lru_cache -from typing import List, Optional +from typing import List, Optional, Union import duckdb import pandas as pd @@ -11,8 +11,6 @@ from pandas.util import hash_pandas_object from dx.settings import get_settings -from dx.utils.datatypes import has_numeric_strings, is_sequence_series -from dx.utils.date_time import is_datetime_series from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns logger = structlog.get_logger(__name__) @@ -64,17 +62,12 @@ def __init__( self.variable_name = get_df_variable_name(df, ipython_shell=ipython_shell) self.original_column_dtypes = df.dtypes.to_dict() - self.sequence_columns = [column for column in df.columns if is_sequence_series(df[column])] - self.datetime_columns = [ - c for c in df.columns if is_datetime_series(df[c]) and not has_numeric_strings(df[c]) - ] self.default_index_used = is_default_index(df.index) - self.index_name = df.index.name or df.index.names or "index" + self.index_name = get_df_index(df.index) self.df = normalize_index_and_columns(df) self.hash = generate_df_hash(self.df) - self.sql_table = f"{self.variable_name}_{self.hash}" self.display_id = SUBSET_TO_DISPLAY_ID.get(self.hash, str(uuid.uuid4())) self.metadata = generate_metadata(self.display_id) @@ -127,6 +120,13 @@ def generate_df_hash(df: pd.DataFrame) -> str: return hash_str +def get_df_index(index: Union[pd.Index, pd.MultiIndex]): + index_name = 
index.name + if index_name is None and isinstance(index, pd.MultiIndex): + index_name = index.names + return index_name + + def get_df_variable_name( df: pd.DataFrame, ipython_shell: Optional[InteractiveShell] = None, From ecc280efd75f99f5c396d6a3b72b52c4a8d13f79 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Mon, 26 Sep 2022 13:07:12 -0400 Subject: [PATCH 6/7] add tests for push-down filtering --- tests/test_filtering.py | 90 ++++++++++++++++++++++++++++++++++++++++- tests/test_tracking.py | 27 +++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index a9d696f7..aeadefb4 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -1,6 +1,16 @@ -from dx.filtering import store_sample_to_history +import duckdb +import pandas as pd +import pytest +from IPython.terminal.interactiveshell import TerminalInteractiveShell + +from dx.filtering import handle_resample, store_sample_to_history +from dx.formatters.main import handle_format +from dx.settings import get_settings, settings_context +from dx.types import DEXResampleMessage from dx.utils.tracking import DXDF_CACHE, DXDataFrame +settings = get_settings() + def test_store_sample_to_history( sample_dxdataframe: DXDataFrame, @@ -21,3 +31,81 @@ def test_store_sample_to_history( assert sample_dxdataframe.metadata["datalink"]["applied_filters"] == sample_dex_filters assert sample_dxdataframe.metadata["datalink"]["sampling_time"] is not None + + +@pytest.mark.parametrize("display_mode", ["simple", "enhanced"]) +def test_resample_from_db( + mocker, + get_ipython: TerminalInteractiveShell, + sample_random_dataframe: pd.DataFrame, + sample_db_connection: duckdb.DuckDBPyConnection, + sample_dex_filters: list, + display_mode: str, +): + """ + Ensure dataframes stored in the kernel's local database + can be resampled with DEX-provided filters. + """ + get_ipython.user_ns["test_df"] = sample_random_dataframe + + mocker.patch("dx.formatters.main.db_connection", sample_db_connection) + mocker.patch("dx.filtering.db_connection", sample_db_connection) + + with settings_context(enable_datalink=True, display_mode=display_mode): + _, metadata = handle_format( + sample_random_dataframe, + ipython_shell=get_ipython, + ) + + resample_msg = DEXResampleMessage( + display_id=metadata[settings.MEDIA_TYPE]["display_id"], + filters=sample_dex_filters, + limit=50_000, + cell_id=None, + ) + try: + handle_resample( + resample_msg, + ipython_shell=get_ipython, + ) + except Exception as e: + assert False, f"Resample failed with error: {e}" + + +@pytest.mark.parametrize("display_mode", ["simple", "enhanced"]) +def test_resample_groupby_from_db( + mocker, + get_ipython: TerminalInteractiveShell, + sample_groupby_dataframe: pd.DataFrame, + sample_db_connection: duckdb.DuckDBPyConnection, + sample_dex_groupby_filters: list, + display_mode: str, +): + """ + Ensure dataframes stored in the kernel's local database + can be resampled with DEX-provided filters. 
+ """ + get_ipython.user_ns["test_df"] = sample_groupby_dataframe + + mocker.patch("dx.formatters.main.db_connection", sample_db_connection) + mocker.patch("dx.filtering.db_connection", sample_db_connection) + + with settings_context(enable_datalink=True, display_mode=display_mode): + _, metadata = handle_format( + sample_groupby_dataframe, + ipython_shell=get_ipython, + ) + + resample_msg = DEXResampleMessage( + display_id=metadata[settings.MEDIA_TYPE]["display_id"], + filters=sample_dex_groupby_filters, + limit=50_000, + cell_id=None, + ) + try: + handle_resample( + resample_msg, + ipython_shell=get_ipython, + ) + except Exception as e: + assert False, f"Resample failed with error: {e}" diff --git a/tests/test_tracking.py b/tests/test_tracking.py index a78d2b5d..b6001e10 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -1,6 +1,9 @@ +import duckdb import pandas as pd from IPython.terminal.interactiveshell import TerminalInteractiveShell +from dx.formatters.main import handle_format +from dx.settings import settings_context from dx.utils.formatting import normalize_index_and_columns from dx.utils.tracking import DXDataFrame, generate_df_hash @@ -102,3 +105,27 @@ def test_dxdataframe_metadata( assert isinstance(datalink_metadata["applied_filters"], list) assert isinstance(datalink_metadata["sample_history"], list) + + +def test_store_in_db( + mocker, + get_ipython: TerminalInteractiveShell, + sample_random_dataframe: pd.DataFrame, + sample_db_connection: duckdb.DuckDBPyConnection, +): + """ + Ensure dataframes are stored as tables using the kernel's + local database connection. + """ + mocker.patch("dx.formatters.main.db_connection", sample_db_connection) + + get_ipython.user_ns["test_df"] = sample_random_dataframe + + with settings_context(enable_datalink=True): + handle_format( + sample_random_dataframe, + ipython_shell=get_ipython, + ) + + resp = sample_db_connection.execute("SELECT COUNT(*) FROM test_df").fetchone() + assert resp[0] == len(sample_random_dataframe) From adb4b6f0bbfece20e25d99d9ecd18471ed85ed94 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Mon, 26 Sep 2022 13:14:30 -0400 Subject: [PATCH 7/7] poetry --- poetry.lock | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index 50deb867..89b6b86a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -108,7 +108,7 @@ dev = ["build (==0.8.0)", "flake8 (==4.0.1)", "hashin (==0.17.0)", "pip-tools (= [[package]] name = "certifi" -version = "2022.9.14" +version = "2022.9.24" description = "Python package for providing Mozilla's CA Bundle." category = "main" optional = false @@ -269,12 +269,15 @@ python-versions = ">=3.6" [[package]] name = "executing" -version = "1.0.0" +version = "1.1.0" description = "Get the currently executing AST node of a frame, and other information" category = "main" optional = false python-versions = "*" +[package.extras] +tests = ["rich", "littleutils", "pytest", "asttokens"] + [[package]] name = "faker" version = "14.2.1" @@ -399,7 +402,7 @@ docs = ["sphinx"] [[package]] name = "griffe" -version = "0.22.1" +version = "0.22.2" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." 
category = "main" optional = true @@ -767,7 +770,7 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-jupyter" -version = "0.21.0" +version = "0.22.0" description = "Use Jupyter in mkdocs websites" category = "main" optional = true @@ -881,7 +884,7 @@ test = ["black", "check-manifest", "flake8", "ipykernel", "ipython", "ipywidgets [[package]] name = "nbconvert" -version = "6.5.3" +version = "6.5.4" description = "Converting Jupyter Notebooks" category = "main" optional = true @@ -915,7 +918,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] [[package]] name = "nbformat" -version = "5.6.0" +version = "5.6.1" description = "The Jupyter Notebook format" category = "main" optional = true @@ -1450,7 +1453,7 @@ sqlcipher = ["sqlcipher3-binary"] [[package]] name = "stack-data" -version = "0.5.0" +version = "0.5.1" description = "Extract data from python stack frames and tracebacks for informative displays" category = "main" optional = false @@ -1618,7 +1621,7 @@ docs = ["mkdocs", "mkdocs-material", "mkdocs-jupyter", "mkdocstrings", "mkdocstr [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "dbdf7f19c805328ea053818514b4cca3e23c8f498c8a557436e4ec817d82e922" +content-hash = "7179986a7629ad639acc44a7dace81a1cf06a20995890b96e29e720e9600da08" [metadata.files] appnope = [