use duckdb instead of sqlite #52

Merged · 9 commits · Sep 26, 2022
Changes from all commits
968 changes: 160 additions & 808 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -21,6 +21,8 @@
mkdocs-material = {version = "^8.3.9", optional = true}
mkdocs-jupyter = {version = ">=0.21,<0.23", optional = true}
mkdocstrings = {version = "^0.19.0", optional = true}
mkdocstrings-python = {version = "^0.7.1", optional = true}
duckdb = "^0.5.1"
duckdb-engine = "^0.6.4"

[tool.poetry.dev-dependencies]
black = "^22.6.0"
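
The two new dependencies replace the SQLAlchemy/sqlite stack: duckdb is the embedded database itself, and duckdb-engine is its SQLAlchemy dialect (the code in this PR uses the native client). A minimal sketch of the pattern the rest of the diff builds on, with illustrative variable and table names:

import duckdb
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

conn = duckdb.connect(database=":memory:", read_only=False)
conn.register("df", df)  # exposes the DataFrame as a queryable view, no copy or INSERTs
subset = conn.execute("SELECT * FROM df WHERE a > 1").df()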
30 changes: 20 additions & 10 deletions src/dx/filtering.py
@@ -3,14 +3,15 @@
import pandas as pd
import structlog
from IPython.display import update_display
from IPython.terminal.interactiveshell import InteractiveShell

from dx.sampling import get_df_dimensions
from dx.settings import get_settings, settings_context
from dx.types import DEXFilterSettings, DEXResampleMessage
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, sql_engine
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, generate_df_hash, get_db_connection

logger = structlog.get_logger(__name__)

db_connection = get_db_connection()
settings = get_settings()


@@ -52,6 +53,7 @@ def update_display_id(
output_variable_name: Optional[str] = None,
limit: Optional[int] = None,
cell_id: Optional[str] = None,
ipython_shell: Optional[InteractiveShell] = None,
) -> None:
"""
Filters the dataframe in the cell with the given display_id.
@@ -65,21 +67,26 @@
row_limit = limit or settings.DISPLAY_MAX_ROWS
dxdf = DXDF_CACHE[display_id]

query_string = sql_filter.format(table_name=dxdf.sql_table)
query_string = sql_filter.format(table_name=dxdf.variable_name)
logger.debug(f"sql query string: {query_string}")
new_df = pd.read_sql(query_string, sql_engine)

with sql_engine.connect() as conn:
orig_df_count = conn.execute(f"SELECT COUNT (*) FROM {dxdf.sql_table}").scalar()
new_df: pd.DataFrame = db_connection.execute(query_string).df()
count_resp = db_connection.execute(f"SELECT COUNT(*) FROM {dxdf.variable_name}").fetchone()
# should return a tuple of (count,)
orig_df_count = count_resp[0]
logger.debug(f"filtered to {len(new_df)}/{orig_df_count} row(s)")

metadata = store_sample_to_history(new_df, display_id=display_id, filters=filters)

# resetting original index
new_df.set_index(dxdf.index_name, inplace=True)
# resetting original index if needed
if dxdf.index_name is not None:
new_df.set_index(dxdf.index_name, inplace=True)

# convert back to original dtypes
for col, dtype in dxdf.original_column_dtypes.items():
if settings.FLATTEN_COLUMN_VALUES and isinstance(col, tuple):
# the dataframe in use originally had pd.MultiIndex columns
col = ", ".join(col)
new_df[col] = new_df[col].astype(dtype)

# this is associating the subset with the original dataframe,
@@ -104,7 +111,10 @@ def update_display_id(
)


def handle_resample(msg: DEXResampleMessage) -> None:
def handle_resample(
msg: DEXResampleMessage,
ipython_shell: Optional[InteractiveShell] = None,
) -> None:
raw_filters = msg.filters
sample_size = msg.limit

@@ -130,4 +140,4 @@ def handle_resample(msg: DEXResampleMessage) -> None:
}
)

update_display_id(**update_params)
update_display_id(ipython_shell=ipython_shell, **update_params)
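
Two duckdb retrieval styles appear in the rewritten update_display_id(): .df() materializes a result set as a pandas DataFrame, while fetchone() returns a plain row tuple, which is why the count is unpacked from count_resp[0]. A small standalone sketch (the cities table is hypothetical):

import duckdb
import pandas as pd

conn = duckdb.connect(":memory:")
conn.register("cities", pd.DataFrame({"name": ["Oslo", "Lima"], "pop": [709000, 9750000]}))

subset = conn.execute("SELECT * FROM cities WHERE pop > 1000000").df()  # pandas DataFrame
count = conn.execute("SELECT COUNT(*) FROM cities").fetchone()[0]  # first item of a (count,) tuple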
7 changes: 4 additions & 3 deletions src/dx/formatters/main.py
@@ -14,10 +14,10 @@
from dx.types import DXDisplayMode
from dx.utils.datatypes import to_dataframe
from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, store_in_sqlite
from dx.utils.tracking import DXDF_CACHE, SUBSET_TO_DISPLAY_ID, DXDataFrame, get_db_connection

logger = structlog.get_logger(__name__)

db_connection = get_db_connection()

DEFAULT_IPYTHON_DISPLAY_FORMATTER = DisplayFormatter()
IN_NOTEBOOK_ENV = False
@@ -57,7 +57,8 @@ def datalink_processing(
# this needs to happen after sending to the frontend
# so the user doesn't wait as long for writing larger datasets
if not parent_display_id:
store_in_sqlite(dxdf.sql_table, dxdf.df)
logger.debug(f"registering `{dxdf.variable_name}` to duckdb")
db_connection.register(dxdf.variable_name, dxdf.df.reset_index())

return payload, metadata

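
The reset_index() in the register() call matters: a registered view exposes only the DataFrame's columns, so the index is promoted to a column to stay queryable, and filtering.py restores it afterward with set_index(). A sketch of that round trip, assuming a named index:

import duckdb
import pandas as pd

df = pd.DataFrame({"value": [10, 20]}, index=pd.Index([100, 200], name="id"))

conn = duckdb.connect(":memory:")
conn.register("df", df.reset_index())  # "id" becomes an ordinary column
roundtrip = conn.execute("SELECT * FROM df").df().set_index("id")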
1 change: 1 addition & 0 deletions src/dx/settings.py
@@ -69,6 +69,7 @@ class Settings(BaseSettings):
# controls dataframe variable tracking, hashing, and storing in sqlite
ENABLE_DATALINK: bool = True
NUM_PAST_SAMPLES_TRACKED: int = 3
DB_LOCATION: str = ":memory:"

@validator("RENDERABLE_OBJECTS", pre=True, always=True)
def validate_renderables(cls, vals):
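
The new DB_LOCATION setting is passed straight to duckdb.connect() (see tracking.py below), so the ":memory:" default keeps tables in RAM for the life of the process, while a file path would persist them:

import duckdb

conn = duckdb.connect(database=":memory:", read_only=False)  # default: in-memory, per-process
# conn = duckdb.connect(database="dx.duckdb", read_only=False)  # hypothetical on-disk location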
43 changes: 16 additions & 27 deletions src/dx/utils/tracking.py
@@ -1,21 +1,19 @@
import hashlib
import uuid
from typing import List, Optional
from functools import lru_cache
from typing import List, Optional, Union

import duckdb
import pandas as pd
import structlog
from IPython import get_ipython
from IPython.core.interactiveshell import InteractiveShell
from pandas.util import hash_pandas_object
from sqlalchemy import create_engine

from dx.settings import get_settings
from dx.utils.datatypes import has_numeric_strings, is_sequence_series
from dx.utils.date_time import is_datetime_series
from dx.utils.formatting import generate_metadata, is_default_index, normalize_index_and_columns

logger = structlog.get_logger(__name__)
sql_engine = create_engine("sqlite://", echo=False)
settings = get_settings()


@@ -27,6 +25,11 @@
SUBSET_TO_DISPLAY_ID = {}


@lru_cache
def get_db_connection() -> duckdb.DuckDBPyConnection:
return duckdb.connect(database=settings.DB_LOCATION, read_only=False)


class DXDataFrame:
"""
Convenience class to store information about dataframes,
@@ -59,17 +62,12 @@ def __init__(
self.variable_name = get_df_variable_name(df, ipython_shell=ipython_shell)

self.original_column_dtypes = df.dtypes.to_dict()
self.sequence_columns = [column for column in df.columns if is_sequence_series(df[column])]
self.datetime_columns = [
c for c in df.columns if is_datetime_series(df[c]) and not has_numeric_strings(df[c])
]

self.default_index_used = is_default_index(df.index)
self.index_name = df.index.name or "index"
self.index_name = get_df_index(df.index)

self.df = normalize_index_and_columns(df)
self.hash = generate_df_hash(self.df)
self.sql_table = f"{self.variable_name}_{self.hash}"
self.display_id = SUBSET_TO_DISPLAY_ID.get(self.hash, str(uuid.uuid4()))

self.metadata = generate_metadata(self.display_id)
@@ -122,6 +120,13 @@ def generate_df_hash(df: pd.DataFrame) -> str:
return hash_str


def get_df_index(index: Union[pd.Index, pd.MultiIndex]):
index_name = index.name
if index_name is None and isinstance(index, pd.MultiIndex):
index_name = index.names
return index_name


def get_df_variable_name(
df: pd.DataFrame,
ipython_shell: Optional[InteractiveShell] = None,
@@ -169,19 +174,3 @@ def get_df_variable_name(
logger.debug("no variables found matching this dataframe")
df_uuid = f"unk_dataframe_{uuid.uuid4()}".replace("-", "")
return df_uuid


def store_in_sqlite(table_name: str, df: pd.DataFrame):
logger.debug(f"{df.columns=}")
tracking_df = df.copy()

logger.debug(f"writing to `{table_name}` table in sqlite")
with sql_engine.begin() as conn:
num_written_rows = tracking_df.to_sql(
table_name,
con=conn,
if_exists="replace",
index=True, # this is the default, but just to be explicit
)
logger.debug(f"wrote {num_written_rows} row(s) to `{table_name}` table")
return num_written_rows
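
Two details worth noting in this rewrite. First, @lru_cache on the zero-argument get_db_connection() turns it into a process-wide singleton, so filtering.py and formatters/main.py share one connection:

from functools import lru_cache

import duckdb

@lru_cache
def get_db_connection() -> duckdb.DuckDBPyConnection:
    return duckdb.connect(database=":memory:", read_only=False)

assert get_db_connection() is get_db_connection()  # same object on every call

Second, get_df_index() replaces the old `df.index.name or "index"` fallback because a pd.MultiIndex carries its labels in .names rather than .name; index_name may now be None, a single name, or a list of names, which is why filtering.py guards its set_index() call with `if dxdf.index_name is not None`.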
46 changes: 43 additions & 3 deletions tests/conftest.py
@@ -1,3 +1,4 @@
import duckdb
import numpy as np
import pandas as pd
import pytest
@@ -38,6 +39,13 @@ def get_ipython() -> TerminalInteractiveShell:
return shell


@pytest.fixture
def sample_db_connection() -> duckdb.DuckDBPyConnection:
conn = duckdb.connect(":memory:")
yield conn
conn.close()


@pytest.fixture
def sample_dataframe() -> pd.DataFrame:
df = pd.DataFrame(
@@ -106,7 +114,18 @@ def sample_long_wide_dataframe() -> pd.DataFrame:
@pytest.fixture
def sample_dex_date_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
"column": "date_column",
"column": "datetime_column",
"type": "DATE_FILTER",
"predicate": "between",
"start": sample_random_dataframe.datetime_column.min(),
"end": sample_random_dataframe.datetime_column.max(),
}


@pytest.fixture
def sample_dex_groupby_date_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
"column": "datetime_column, min",
"type": "DATE_FILTER",
"predicate": "between",
"start": sample_random_dataframe.datetime_column.min(),
@@ -127,6 +146,19 @@ def sample_dex_metric_filter(sample_random_dataframe: pd.DataFrame) -> dict:
}


@pytest.fixture
def sample_dex_groupby_metric_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
"column": "float_column, min",
"type": "METRIC_FILTER",
"predicate": "between",
"value": [
sample_random_dataframe.float_column.min(),
sample_random_dataframe.float_column.max(),
],
}


@pytest.fixture
def sample_dex_dimension_filter(sample_random_dataframe: pd.DataFrame) -> dict:
return {
@@ -153,8 +185,16 @@ def sample_dex_filters(


@pytest.fixture
def sample_display_id() -> str:
return "test-display-123"
def sample_dex_groupby_filters(
sample_dex_groupby_date_filter: dict,
sample_dex_groupby_metric_filter: dict,
) -> list:
return DEXFilterSettings(
filters=[
sample_dex_groupby_date_filter,
sample_dex_groupby_metric_filter,
]
).filters


@pytest.fixture
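
The new sample_db_connection fixture yields a fresh in-memory connection and closes it after each test. A hypothetical test showing the intended usage:

import duckdb
import pandas as pd

def test_register_roundtrip(sample_db_connection: duckdb.DuckDBPyConnection):
    df = pd.DataFrame({"x": [1, 2, 3]})
    sample_db_connection.register("df_test", df)
    count = sample_db_connection.execute("SELECT COUNT(*) FROM df_test").fetchone()[0]
    assert count == len(df)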
26 changes: 16 additions & 10 deletions tests/test_datatypes.py
@@ -2,9 +2,10 @@
Tests to ensure various data types can be sent functions to
- build the table schema and payload/metadata body for each display formatter
- hash the dataframe for tracking
- write to sqlite for tracking/filtering
- write to the database for tracking/filtering
"""

import duckdb
import pandas as pd
import pytest
from pandas.io.json import build_table_schema
@@ -19,7 +20,7 @@
random_dataframe,
)
from dx.utils.formatting import clean_column_values
from dx.utils.tracking import generate_df_hash, sql_engine, store_in_sqlite
from dx.utils.tracking import generate_df_hash


@pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES)
@@ -135,37 +136,42 @@ def test_generate_df_hash(dtype: str):

@pytest.mark.xfail(reason="only for dev")
@pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES)
def test_to_sql(dtype: str):
def test_to_sql(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection):
"""
DEV: Test which data types pass/fail when passed directly through .to_sql()
with the sqlalchemy engine.
"""
params = {dt: False for dt in SORTED_DX_DATATYPES}
params[dtype] = True

df = random_dataframe(**params)

try:
with sql_engine.connect() as conn:
num_rows = df.to_sql("test", conn, if_exists="replace")
sample_db_connection.register(f"{dtype}_test", df)
except Exception as e:
assert False, f"{dtype} failed with {e}"
count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone()
num_rows = count_resp[0]
assert num_rows == df.shape[0]


@pytest.mark.parametrize("dtype", SORTED_DX_DATATYPES)
def test_store_in_sqlite(dtype: str):
def test_store_in_db(dtype: str, sample_db_connection: duckdb.DuckDBPyConnection):
"""
Test that we've correctly handled data types before storing in sqlite.
"""
params = {dt: False for dt in SORTED_DX_DATATYPES}
params[dtype] = True

df = random_dataframe(**params)

for col in df.columns:
df[col] = clean_column_values(df[col])

try:
num_rows = store_in_sqlite(f"{dtype}_test", df)
sample_db_connection.register(f"{dtype}_test", df)
except Exception as e:
assert False, f"{dtype} failed with {e}"
count_resp = sample_db_connection.execute(f"SELECT COUNT(*) FROM {dtype}_test").fetchone()
num_rows = count_resp[0]
assert num_rows == df.shape[0]


# TODO: test that we can convert back to original datatypes after read_sql?
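
That TODO points at the dtype restoration update_display_id() already performs: rows read back out of the database can come back with widened or object dtypes, so they are cast back through the cached original_column_dtypes mapping. A sketch of the round trip under that assumption:

import duckdb
import pandas as pd

df = pd.DataFrame({"flag": [True, False], "n": [1, 2]})
original_dtypes = df.dtypes.to_dict()

conn = duckdb.connect(":memory:")
conn.register("df", df)
result = conn.execute("SELECT * FROM df").df()
for col, dtype in original_dtypes.items():
    result[col] = result[col].astype(dtype)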