use duckdb instead of sqlite (#52)
* add duckdb
* add date/metric groupby filter fixtures
* handle storing .index better to accommodate groupby/multiindex
* add tests for push-down filtering (see the sketch below)
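
A minimal sketch of the pattern this commit moves to: instead of writing each tracked DataFrame into sqlite via SQLAlchemy, the DataFrame is registered with a DuckDB connection and filtered with plain SQL. The table and column names below are illustrative, not the actual dx internals.

```python
import duckdb
import pandas as pd

# In-memory database, matching the new DB_LOCATION=":memory:" default.
conn = duckdb.connect(database=":memory:", read_only=False)

df = pd.DataFrame({"float_column": [1.0, 2.5, 3.7], "keyword_column": ["a", "b", "c"]})

# register() exposes the DataFrame as a queryable view without copying it
# into a separate store; reset_index() keeps the index around as a column.
conn.register("sample_df", df.reset_index())

# Push-down filtering is then just SQL against the registered view.
filtered = conn.execute(
    "SELECT * FROM sample_df WHERE float_column BETWEEN 1.5 AND 4.0"
).df()
print(len(filtered))  # 2 rows match
```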
shouples authored Sep 26, 2022
1 parent 250f35a commit e6886ad
Showing 10 changed files with 378 additions and 862 deletions.
968 changes: 160 additions & 808 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -21,6 +21,8 @@ mkdocs-material = {version = "^8.3.9", optional = true}
mkdocs-jupyter = {version = ">=0.21,<0.23", optional = true}
mkdocstrings = {version = "^0.19.0", optional = true}
mkdocstrings-python = {version = "^0.7.1", optional = true}
duckdb = "^0.5.1"
duckdb-engine = "^0.6.4"

[tool.poetry.dev-dependencies]
black = "^22.6.0"
30 changes: 20 additions & 10 deletions src/dx/filtering.py
@@ -3,14 +3,15 @@
import pandas as pd
import structlog
from IPython.display import update_display
from IPython.terminal.interactiveshell import InteractiveShell

from dx.sampling import get_df_dimensions
from dx.settings import get_settings, settings_context
from dx.types import DEXFilterSettings, DEXResampleMessage
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, sql_engine
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, get_db_connection

logger = structlog.get_logger(__name__)

db_connection = get_db_connection()
settings = get_settings()


@@ -52,6 +53,7 @@ def update_display_id(
output_variable_name: Optional[str] = None,
limit: Optional[int] = None,
cell_id: Optional[str] = None,
ipython_shell: Optional[InteractiveShell] = None,
) -> None:
"""
Filters the dataframe in the cell with the given display_id.
@@ -65,21 +67,26 @@ row_limit = limit or settings.DISPLAY_MAX_ROWS
row_limit = limit or settings.DISPLAY_MAX_ROWS
dxdf = DXDF_CACHE[display_id]

query_string = sql_filter.format(table_name=dxdf.sql_table)
query_string = sql_filter.format(table_name=dxdf.variable_name)
logger.debug(f"sql query string: {query_string}")
new_df = pd.read_sql(query_string, sql_engine)

with sql_engine.connect() as conn:
orig_df_count = conn.execute(f"SELECT COUNT (*) FROM {dxdf.sql_table}").scalar()
new_df: pd.DataFrame = db_connection.execute(query_string).df()
count_resp = db_connection.execute(f"SELECT COUNT(*) FROM {dxdf.variable_name}").fetchone()
# should return a tuple of (count,)
orig_df_count = count_resp[0]
logger.debug(f"filtered to {len(new_df)}/{orig_df_count} row(s)")

metadata = store_sample_to_history(new_df, display_id=display_id, filters=filters)

# resetting original index
new_df.set_index(dxdf.index_name, inplace=True)
# resetting original index if needed
if dxdf.index_name is not None:
new_df.set_index(dxdf.index_name, inplace=True)

# convert back to original dtypes
for col, dtype in dxdf.original_column_dtypes.items():
if settings.FLATTEN_COLUMN_VALUES and isinstance(col, tuple):
# the dataframe in use originally had pd.MultiIndex columns
col = ", ".join(col)
new_df[col] = new_df[col].astype(dtype)

# this is associating the subset with the original dataframe,
@@ -104,7 +111,10 @@ def update_display_id(
)


def handle_resample(msg: DEXResampleMessage) -> None:
def handle_resample(
msg: DEXResampleMessage,
ipython_shell: Optional[InteractiveShell] = None,
) -> None:
raw_filters = msg.filters
sample_size = msg.limit

Expand All @@ -130,4 +140,4 @@ def handle_resample(msg: DEXResampleMessage) -> None:
}
)

update_display_id(**update_params)
update_display_id(ipython_shell=ipython_shell, **update_params)
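
Because the query result comes back as a fresh DataFrame, the code above re-applies the recorded index and casts columns back to their original dtypes; for groupby results whose pd.MultiIndex columns were flattened, the tuple keys are joined with ", " before the lookup. A rough, self-contained sketch of that restoration step, with made-up column names and dtypes:

```python
import pandas as pd

# Captured before the DataFrame was registered with DuckDB (illustrative values).
original_column_dtypes = {
    ("float_column", "min"): "float64",
    ("float_column", "max"): "float64",
}
flatten_column_values = True  # stands in for settings.FLATTEN_COLUMN_VALUES

# What a filtered query might hand back: flattened names, strings instead of floats.
new_df = pd.DataFrame(
    {"float_column, min": ["1.0", "2.0"], "float_column, max": ["3.0", "4.0"]}
)

for col, dtype in original_column_dtypes.items():
    if flatten_column_values and isinstance(col, tuple):
        # the original dataframe had pd.MultiIndex columns
        col = ", ".join(col)
    new_df[col] = new_df[col].astype(dtype)

print(new_df.dtypes)  # both columns restored to float64
```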
7 changes: 4 additions & 3 deletions src/dx/formatters/main.py
@@ -14,10 +14,10 @@
from dx.types import DXDisplayMode
from dx.utils.datatypes import to_dataframe
from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, store_in_sqlite
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, get_db_connection

logger = structlog.get_logger(__name__)

db_connection = get_db_connection()

DEFAULT_IPYTHON_DISPLAY_FORMATTER = DisplayFormatter()
IN_NOTEBOOK_ENV = False
@@ -57,7 +57,8 @@ def datalink_processing(
# this needs to happen after sending to the frontend
# so the user doesn't wait as long for writing larger datasets
if not parent_display_id:
store_in_sqlite(dxdf.sql_table, dxdf.df)
logger.debug(f"registering `{dxdf.variable_name}` to duckdb")
db_connection.register(dxdf.variable_name, dxdf.df.reset_index())

return payload, metadata

1 change: 1 addition & 0 deletions src/dx/settings.py
@@ -69,6 +69,7 @@ class Settings(BaseSettings):
# controls dataframe variable tracking, hashing, and storing in sqlite
ENABLE_DATALINK: bool = True
NUM_PAST_SAMPLES_TRACKED: int = 3
DB_LOCATION: str = ":memory:"

@validator("RENDERABLE_OBJECTS", pre=True, always=True)
def validate_renderables(cls, vals):
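
The new DB_LOCATION setting defaults to an in-memory database; a file path would make get_db_connection() open a persistent DuckDB file instead. A sketch of the underlying duckdb behavior only (the filename is made up, and this is not a documented dx configuration recipe):

```python
import duckdb

# ":memory:" (the default) keeps everything in RAM for the lifetime of the kernel.
mem_conn = duckdb.connect(database=":memory:", read_only=False)

# A file path persists tables created on the connection between sessions; note
# that register()-ed DataFrames remain views over Python objects and are not
# written to disk.
file_conn = duckdb.connect(database="dx_cache.duckdb", read_only=False)
file_conn.execute("CREATE TABLE IF NOT EXISTS example (n INTEGER)")
file_conn.close()
```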
43 changes: 16 additions & 27 deletions src/dx/utils/tracking.py
@@ -1,21 +1,19 @@
import hashlib
import uuid
from typing import List, Optional
from functools import lru_cache
from typing import List, Optional, Union

import duckdb
import pandas as pd
import structlog
from IPython import get_ipython
from IPython.core.interactiveshell import InteractiveShell
from pandas.util import hash_pandas_object
from sqlalchemy import create_engine

from dx.settings import get_settings
from dx.utils.datatypes import has_numeric_strings, is_sequence_series
from dx.utils.date_time import is_datetime_series
from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns

logger = structlog.get_logger(__name__)
sql_engine = create_engine("sqlite://", echo=False)
settings = get_settings()


@@ -27,6 +25,11 @@
SUBSET_TO_DISPLAY_ID = {}


@lru_cache
def get_db_connection() -> duckdb.DuckDBPyConnection:
return duckdb.connect(database=settings.DB_LOCATION, read_only=False)


class DXDataFrame:
"""
Convenience class to store information about dataframes,
@@ -59,17 +62,12 @@ def __init__(
self.variable_name = get_df_variable_name(df, ipython_shell=ipython_shell)

self.original_column_dtypes = df.dtypes.to_dict()
self.sequence_columns = [column for column in df.columns if is_sequence_series(df[column])]
self.datetime_columns = [
c for c in df.columns if is_datetime_series(df[c]) and not has_numeric_strings(df[c])
]

self.default_index_used = is_default_index(df.index)
self.index_name = df.index.name or "index"
self.index_name = get_df_index(df.index)

self.df = normalize_index_and_columns(df)
self.hash = generate_df_hash(self.df)
self.sql_table = f"{self.variable_name}_{self.hash}"
self.display_id = SUBSET_TO_DISPLAY_ID.get(self.hash, str(uuid.uuid4()))

self.metadata = generate_metadata(self.display_id)
@@ -122,6 +120,13 @@ def generate_df_hash(df: pd.DataFrame) -> str:
return hash_str


def get_df_index(index: Union[pd.Index, pd.MultiIndex]):
index_name = index.name
if index_name is None and isinstance(index, pd.MultiIndex):
index_name = index.names
return index_name


def get_df_variable_name(
df: pd.DataFrame,
ipython_shell: Optional[InteractiveShell] = None,
@@ -169,19 +174,3 @@ def get_df_variable_name(
logger.debug("no variables found matching this dataframe")
df_uuid = f"unk_dataframe_{uuid.uuid4()}".replace("-", "")
return df_uuid


def store_in_sqlite(table_name: str, df: pd.DataFrame):
logger.debug(f"{df.columns=}")
tracking_df = df.copy()

logger.debug(f"writing to `{table_name}` table in sqlite")
with sql_engine.begin() as conn:
num_written_rows = tracking_df.to_sql(
table_name,
con=conn,
if_exists="replace",
index=True, # this is the default, but just to be explicit
)
logger.debug(f"wrote {num_written_rows} row(s) to `{table_name}` table")
return num_written_rows
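
The index handling above is what lets grouped frames survive the round trip: a default RangeIndex has no name (so there is nothing to restore), a single named index contributes one name, and a groupby over several keys produces a pd.MultiIndex whose level names get_df_index() records as a list. A small illustration with made-up data:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "keyword_column": ["a", "a", "b"],
        "integer_column": [1, 1, 2],
        "float_column": [1.0, 2.0, 3.0],
    }
)
print(df.index.name)  # None -> default RangeIndex, nothing to restore

# A two-key groupby yields a pd.MultiIndex; its level names are what need
# to be remembered so the filtered subset can be re-indexed later.
grouped = df.groupby(["keyword_column", "integer_column"]).agg({"float_column": "min"})
print(list(grouped.index.names))  # ['keyword_column', 'integer_column']

# reset_index() before registering with DuckDB turns those levels into columns,
# and set_index() with the recorded names restores them afterwards.
flat = grouped.reset_index()
restored = flat.set_index(list(grouped.index.names))
print(list(restored.index.names))  # ['keyword_column', 'integer_column']
```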
46 changes: 43 additions & 3 deletions tests/conftest.py
@@ -1,3 +1,4 @@
import duckdb
import numpy as np
import pandas as pd
import pytest
@@ -38,6 +39,13 @@ def get_ipython() -> TerminalInteractiveShell:
return shell


@pytest.fixture
def sample_db_connection() -> duckdb.DuckDBPyConnection:
conn = duckdb.connect(":memory:")
yield conn
conn.close()


@pytest.fixture
def sample_dataframe() -> pd.DataFrame:
df = pd.DataFrame(
@@ -106,7 +114,18 @@ def sample_long_wide_dataframe() -> pd.DataFrame:
@pytest.fixture
def sample_dex_date_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
"column": "date_column",
"column": "datetime_column",
"type": "DATE_FILTER",
"predicate": "between",
"start": sample_random_dataframe.datetime_column.min(),
"end": sample_random_dataframe.datetime_column.max(),
}


@pytest.fixture
def sample_dex_groupby_date_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
"column": "datetime_column, min",
"type": "DATE_FILTER",
"predicate": "between",
"start": sample_random_dataframe.datetime_column.min(),
@@ -127,6 +146,19 @@ def sample_dex_metric_filter(sample_random_dataframe: pd.DataFrame) -> dict:
}


@pytest.fixture
def sample_dex_groupby_metric_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
"column": "float_column, min",
"type": "METRIC_FILTER",
"predicate": "between",
"value": [
sample_random_dataframe.float_column.min(),
sample_random_dataframe.float_column.max(),
],
}


@pytest.fixture
def sample_dex_dimension_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
@@ -153,8 +185,16 @@ def sample_dex_filters(


@pytest.fixture
def sample_display_id() -> str:
return "test-display-123"
def sample_dex_groupby_filters(
sample_dex_groupby_date_filter: dict,
sample_dex_groupby_metric_filter: dict,
) -> list:
return DEXFilterSettings(
filters=[
sample_dex_groupby_date_filter,
sample_dex_groupby_metric_filter,
]
).filters


@pytest.fixture
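
The groupby fixtures target column names like "float_column, min" because aggregating with multiple functions produces pd.MultiIndex columns, which dx flattens (under FLATTEN_COLUMN_VALUES) by joining each tuple with ", ". A short sketch of where those names come from, with illustrative data:

```python
import pandas as pd

df = pd.DataFrame({"keyword_column": ["a", "a", "b"], "float_column": [1.0, 2.0, 3.0]})

# Multiple aggregations per column yield pd.MultiIndex columns...
grouped = df.groupby("keyword_column").agg({"float_column": ["min", "max"]})
print(list(grouped.columns))  # [('float_column', 'min'), ('float_column', 'max')]

# ...which flatten to the comma-joined names the groupby filter fixtures reference.
grouped.columns = [", ".join(col) for col in grouped.columns]
print(list(grouped.columns))  # ['float_column, min', 'float_column, max']
```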
26 changes: 16 additions & 10 deletions tests/test_datatypes.py
@@ -2,9 +2,10 @@
Tests to ensure various data types can be sent functions to
- build the table schema and payload/metadata body for each display formatter
- hash the dataframe for tracking
- write to sqlite for tracking/filtering
- write to the database for tracking/filtering
"""

import duckdb
import pandas as pd
import pytest
from pandas.io.json import build_table_schema
@@ -19,7 +20,7 @@
random_dataframe,
)
from dx.utils.formatting import clean_column_values
from dx.utils.tracking import generate_df_hash, sql_engine, store_in_sqlite
from dx.utils.tracking import generate_df_hash


@pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES)
@@ -135,37 +136,42 @@ def test_generate_df_hash(dtype: str):

@pytest.mark.xfail(reason="only for dev")
@pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES)
def test_to_sql(dtype: str):
def test_to_sql(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection):
"""
DEV: Test which data types pass/fail when passed directly through .to_sql()
with the sqlalchemy engine.
"""
params = {dt: False for dt in SORTED_DX_DATATYPES}
params[dtype] = True

df = random_dataframe(**params)

try:
with sql_engine.connect() as conn:
num_rows = df.to_sql("test", conn, if_exists="replace")
sample_db_connection.register(f"{dtype}_test", df)
except Exception as e:
assert False, f"{dtype} failed with {e}"
count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone()
num_rows = count_resp[0]
assert num_rows == df.shape[0]


@pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES)
def test_store_in_sqlite(dtype: str):
def test_store_in_db(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection):
"""
Test that we've correctly handled data types before storing in sqlite.
"""
params = {dt: False for dt in SORTED_DX_DATATYPES}
params[dtype] = True

df = random_dataframe(**params)

for col in df.columns:
df[col] = clean_column_values(df[col])

try:
num_rows = store_in_sqlite(f"{dtype}_test", df)
sample_db_connection.register(f"{dtype}_test", df)
except Exception as e:
assert False, f"{dtype} failed with {e}"
count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone()
num_rows = count_resp[0]
assert num_rows == df.shape[0]


# TODO: test that we can convert back to original datatypes after read_sql?