feat(datasets): Extend preview mechanism (kedro-org#595)
* Extend preview to Parquet

* Update sql_dataset.py

* Update sql_dataset.py

* update preview method for parquetdataset

* Update sql_dataset.py

* extend preview to JSONDataset

* add json preview

* add preview for pickledataset

* Update json_dataset.py

* lint

* add tests for parquet and json

* lint

* rem pickle fix docstring

* Fix parquet test

* fix pandas.json tests

* add coverage for sqldataset

* lint

* coverage for sanitisation of sql

* changes based on review

* use pyarrow for parquet preview

* align jsondataset with spike

* Update json_dataset.py

* Update json_dataset.py

* pass lines=true and nrows

* update docstring

* Update test_json_dataset.py

* revert change

* use sqlalchemy instead of query

* fix sql tests

Signed-off-by: tgoelles <[email protected]>
SajidAlamQB authored and tgoelles committed Jun 6, 2024
1 parent 9c1bfa3 commit 2e96c9f
Showing 8 changed files with 220 additions and 1 deletion.
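All of the tabular previews added below serialise a DataFrame with pandas' "split" orientation. As a quick illustration of that shape (with made-up data):

import pandas as pd

df = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})
print(df.to_dict(orient="split"))
# {'index': [0, 1], 'columns': ['col1', 'col2'], 'data': [[1, 'a'], [2, 'b']]}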
13 changes: 13 additions & 0 deletions kedro-datasets/kedro_datasets/json/json_dataset.py
@@ -16,6 +16,8 @@
get_protocol_and_path,
)

from kedro_datasets._typing import JSONPreview


class JSONDataset(AbstractVersionedDataset[Any, Any]):
"""``JSONDataset`` loads/saves data from/to a JSON file using an underlying
@@ -160,3 +162,14 @@ def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def preview(self) -> JSONPreview:
    """
    Generate a preview of the JSON dataset.

    Returns:
        A string representation of the JSON data for previewing.
    """
    data = self._load()

    return json.dumps(data)
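For context, a minimal usage sketch of the new method (the local data.json path is hypothetical):

from kedro_datasets.json import JSONDataset

dataset = JSONDataset(filepath="data.json")  # hypothetical local file
dataset.save({"a": 1, "b": [2, 3]})

print(dataset.preview())  # '{"a": 1, "b": [2, 3]}'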
23 changes: 23 additions & 0 deletions kedro-datasets/kedro_datasets/pandas/json_dataset.py
@@ -19,6 +19,8 @@
get_protocol_and_path,
)

from kedro_datasets._typing import TablePreview

logger = logging.getLogger(__name__)


@@ -188,3 +190,24 @@ def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def preview(self, nrows: int = 5) -> TablePreview:
    """
    Generate a preview of the dataset with a specified number of rows,
    including handling for both flat and nested JSON structures.

    Args:
        nrows: Number of rows to include in the preview. Defaults to 5.

    Returns:
        dict: A dictionary of the data in pandas' split format.
    """
    # Work on a copy so the preview-specific load args don't mutate the
    # original dataset. pandas only honours ``nrows`` for line-delimited
    # JSON, so ``lines`` defaults to True here.
    dataset_copy = self._copy()
    dataset_copy._load_args.setdefault("lines", True)
    dataset_copy._load_args["nrows"] = nrows
    preview_df = dataset_copy._load()

    preview_dict = preview_df.to_dict(orient="split")

    return preview_dict
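A short sketch of how this behaves against a JSON Lines file (the data.jsonl path and records are illustrative):

from kedro_datasets.pandas import JSONDataset

# Assumes data.jsonl contains one JSON object per line.
dataset = JSONDataset(filepath="data.jsonl", load_args={"lines": True})
preview = dataset.preview(nrows=2)

print(preview["columns"])  # e.g. ['name', 'age']
print(preview["data"])     # the first two rows as lists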
23 changes: 23 additions & 0 deletions kedro-datasets/kedro_datasets/pandas/parquet_dataset.py
@@ -19,6 +19,8 @@
get_protocol_and_path,
)

from kedro_datasets._typing import TablePreview

logger = logging.getLogger(__name__)


@@ -214,3 +216,24 @@ def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)

def preview(self, nrows: int = 5) -> TablePreview:
    """
    Generate a preview of the dataset with a specified number of rows.

    Args:
        nrows: The number of rows to include in the preview. Defaults to 5.

    Returns:
        dict: A dictionary containing the data in a split format.
    """
    import pyarrow.parquet as pq

    load_path = str(self._get_load_path())

    # Read the table (restricted to any user-specified columns) and
    # keep only the first ``nrows`` rows.
    table = pq.read_table(
        load_path, columns=self._load_args.get("columns"), use_threads=True
    )[:nrows]
    data_preview = table.to_pandas()

    return data_preview.to_dict(orient="split")
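For reference, the same pyarrow technique in isolation (the data.parquet path is illustrative); note that read_table materialises the whole file before the zero-copy [:nrows] row slice:

import pandas as pd
import pyarrow.parquet as pq

pd.DataFrame({"col1": range(6), "col2": range(6)}).to_parquet("data.parquet")

table = pq.read_table("data.parquet", use_threads=True)[:3]  # first 3 rows
print(table.to_pandas().to_dict(orient="split"))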
29 changes: 28 additions & 1 deletion kedro-datasets/kedro_datasets/pandas/sql_dataset.py
@@ -16,9 +16,11 @@
get_filepath_str,
get_protocol_and_path,
)
from sqlalchemy import create_engine, inspect
from sqlalchemy import MetaData, Table, create_engine, inspect, select
from sqlalchemy.exc import NoSuchModuleError

from kedro_datasets._typing import TablePreview

__all__ = ["SQLTableDataset", "SQLQueryDataset"]

KNOWN_PIP_INSTALL = {
@@ -275,6 +277,31 @@ def _exists(self) -> bool:
schema = self._load_args.get("schema", None)
return insp.has_table(self._load_args["table_name"], schema)

def preview(self, nrows: int = 5) -> TablePreview:
    """
    Generate a preview of the dataset with a specified number of rows.

    Args:
        nrows: The number of rows to include in the preview. Defaults to 5.

    Returns:
        dict: A dictionary containing the data in a split format.
    """
    table_name = self._load_args["table_name"]

    # Reflect the table schema and issue a LIMIT query instead of
    # loading the full table into memory.
    metadata = MetaData()
    table_ref = Table(table_name, metadata, autoload_with=self.engine)

    query = select(table_ref).limit(nrows)

    with self.engine.connect() as conn:
        result = conn.execute(query)
        data_preview = pd.DataFrame(result.fetchall(), columns=result.keys())

    preview_data = data_preview.to_dict(orient="split")
    return preview_data


class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
"""``SQLQueryDataset`` loads data from a provided SQL query. It
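For reference, the reflection-plus-LIMIT pattern used in SQLTableDataset.preview above, as a self-contained sketch against an in-memory SQLite database (table name and data are illustrative):

import pandas as pd
from sqlalchemy import MetaData, Table, create_engine, select

engine = create_engine("sqlite:///:memory:")
pd.DataFrame({"col1": [1, 2, 3]}).to_sql("test_table", engine, index=False)

metadata = MetaData()
table_ref = Table("test_table", metadata, autoload_with=engine)  # reflect schema

with engine.connect() as conn:
    result = conn.execute(select(table_ref).limit(2))
    print(pd.DataFrame(result.fetchall(), columns=result.keys()))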
21 changes: 21 additions & 0 deletions kedro-datasets/tests/json/test_json_dataset.py
@@ -1,3 +1,5 @@
import inspect
import json
from pathlib import Path, PurePosixPath

import pytest
@@ -197,3 +199,22 @@ def test_versioning_existing_dataset(
Path(json_dataset._filepath.as_posix()).unlink()
versioned_json_dataset.save(dummy_data)
assert versioned_json_dataset.exists()

def test_preview(self, json_dataset, dummy_data):
"""Test the preview method."""
json_dataset.save(dummy_data)
preview_data = json_dataset.preview()

# Load the data directly for comparison
with json_dataset._fs.open(json_dataset._get_load_path(), mode="r") as fs_file:
full_data = json.load(fs_file)

expected_data = json.dumps(full_data)

assert (
preview_data == expected_data
), "The preview data does not match the expected data."
assert (
inspect.signature(json_dataset.preview).return_annotation.__name__
== "JSONPreview"
)
37 changes: 37 additions & 0 deletions kedro-datasets/tests/pandas/test_json_dataset.py
@@ -1,3 +1,5 @@
import inspect
import json
from pathlib import Path, PurePosixPath

import pandas as pd
@@ -40,6 +42,17 @@ def dummy_dataframe():
return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})


@pytest.fixture
def json_lines_data(tmp_path):
data = [
{"name": "Alice", "age": 30, "city": "New York"},
{"name": "Bob", "age": 25, "city": "Los Angeles"},
]
filepath = tmp_path / "lines_test.json"
filepath.write_text("\n".join(json.dumps(item) for item in data))
return filepath.as_posix()


class TestJSONDataset:
def test_save_and_load(self, json_dataset, dummy_dataframe):
"""Test saving and reloading the data set."""
@@ -142,6 +155,30 @@ def test_catalog_release(self, mocker):
dataset.release()
fs_mock.invalidate_cache.assert_called_once_with(filepath)

def test_preview_json(self, json_lines_data):
dataset = JSONDataset(filepath=json_lines_data, load_args={"lines": True})
preview_data = dataset.preview(nrows=2)
expected_columns = ["name", "age", "city"]
expected_data = [["Alice", 30, "New York"], ["Bob", 25, "Los Angeles"]]

assert preview_data["columns"] == expected_columns
assert preview_data["data"] == expected_data
assert len(preview_data["data"]) == 2
assert (
inspect.signature(dataset.preview).return_annotation.__name__
== "TablePreview"
)

def test_preview_json_lines(self, json_dataset, json_lines_data):
json_dataset._filepath = json_lines_data
json_dataset._load_args = {"lines": True}
preview_data = json_dataset.preview()
assert len(preview_data["data"]) == 2
assert (
inspect.signature(json_dataset.preview).return_annotation.__name__
== "TablePreview"
)


class TestJSONDatasetVersioned:
def test_version_str_repr(self, load_version, save_version):
38 changes: 38 additions & 0 deletions kedro-datasets/tests/pandas/test_parquet_dataset.py
@@ -1,3 +1,4 @@
import inspect
from pathlib import Path, PurePosixPath

import pandas as pd
@@ -42,6 +43,17 @@ def dummy_dataframe():
return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})


@pytest.fixture
def dummy_dataframe_preview():
return pd.DataFrame(
{
"col1": [1, 2, 3, 4, 5, 6],
"col2": [4, 5, 6, 7, 8, 9],
"col3": [5, 6, 7, 8, 9, 10],
}
)


class TestParquetDataset:
def test_credentials_propagated(self, mocker):
"""Test propagating credentials for connecting to GCS"""
@@ -213,6 +225,32 @@ def test_arg_partition_cols(self, dummy_dataframe, tmp_path):
with pytest.raises(DatasetError, match=pattern):
dataset.save(dummy_dataframe)

@pytest.mark.parametrize(
    "nrows,expected_rows",
    [
        (5, 5),
        (10, 6),  # more rows requested than the 6 in the dummy data
    ],
)
def test_preview(
self, parquet_dataset, dummy_dataframe_preview, nrows, expected_rows
):
"""Test the preview functionality for ParquetDataset."""
parquet_dataset.save(dummy_dataframe_preview)
previewed_data = parquet_dataset.preview(nrows=nrows)

# Assert preview data matches expected rows
assert len(previewed_data["data"]) == expected_rows
# Assert columns match
assert previewed_data["columns"] == list(dummy_dataframe_preview.columns)
assert (
inspect.signature(parquet_dataset.preview).return_annotation.__name__
== "TablePreview"
)


class TestParquetDatasetVersioned:
def test_version_str_repr(self, load_version, save_version):
37 changes: 37 additions & 0 deletions kedro-datasets/tests/pandas/test_sql_dataset.py
@@ -1,10 +1,12 @@
import inspect
from pathlib import PosixPath
from unittest.mock import ANY

import pandas as pd
import pytest
import sqlalchemy
from kedro.io.core import DatasetError
from sqlalchemy.exc import SQLAlchemyError

import kedro_datasets
from kedro_datasets.pandas import SQLQueryDataset, SQLTableDataset
@@ -60,6 +62,19 @@ def query_file_dataset(request, sql_file):
return SQLQueryDataset(**kwargs)


@pytest.fixture
def sql_dataset():
    connection_string = "sqlite:///:memory:"
    table_name = "test_table"

    # Each engine gets its own in-memory SQLite database, so this seed data
    # is not visible to the dataset's own engine; tests must (re-)create the
    # table via ``sql_dataset.engine`` before calling ``preview``.
    engine = sqlalchemy.create_engine(connection_string)
    test_data = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
    test_data.to_sql(table_name, engine, index=False)

    credentials = {"con": connection_string}
    return SQLTableDataset(table_name=table_name, credentials=credentials)


class TestSQLTableDataset:
_unknown_conn = "mysql+unknown_module://scott:tiger@localhost/foo"

@@ -211,6 +226,28 @@ def test_additional_params(self, mocker):
CONNECTION, **additional_params
)

def test_preview_normal_scenario(self, sql_dataset):
engine = sql_dataset.engine
expected_df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
expected_df.to_sql(sql_dataset._load_args["table_name"], engine, index=False)
preview = sql_dataset.preview(nrows=3)

assert "columns" in preview
assert "data" in preview
assert len(preview["data"]) == len(expected_df)

return_annotation = inspect.signature(sql_dataset.preview).return_annotation
assert return_annotation == "TablePreview"

def test_preview_sql_error(self, table_dataset, mocker):
mocker.patch(
"pandas.read_sql_query",
side_effect=SQLAlchemyError("Mocked SQL error", "", ""),
)

with pytest.raises(SQLAlchemyError):
table_dataset.preview(nrows=3)


class TestSQLTableDatasetSingleConnection:
def test_single_connection(self, dummy_dataframe, mocker):
