-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpolars_delta.py
117 lines (99 loc) · 3.84 KB
/
polars_delta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""``CSVDataset`` loads/saves data from/to a CSV file using an underlying
filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file.
"""
import logging
from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any, Dict
import fsspec
import polars as pl
from deltalake.exceptions import TableNotFoundError
from kedro.io.core import (
PROTOCOL_DELIMITER,
AbstractVersionedDataset,
DatasetError,
Version,
get_filepath_str,
get_protocol_and_path,
)
# Module-level logger named after this module (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)
class DeltaDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
    """``DeltaDataset`` loads/saves data from/to a Delta Table using an underlying
    filesystem (e.g.: local, S3, GCS). It returns a Polars dataframe.

    Loading uses ``polars.read_delta`` and saving uses
    ``polars.DataFrame.write_delta``. The merged ``credentials`` and ``fs_args``
    are passed to both as ``storage_options``. Any ``storage_options`` given
    inside ``load_args``/``save_args`` are dropped with a warning.

    Args:
        filepath: Table location, optionally prefixed with a protocol
            (e.g. ``s3://bucket/table``).
        load_args: Extra keyword arguments forwarded to ``polars.read_delta``.
        save_args: Extra keyword arguments forwarded to
            ``polars.DataFrame.write_delta``.
        version: Kedro ``Version`` used for versioned load/save paths.
        credentials: Merged into the Delta ``storage_options``.
        fs_args: Arguments for the ``fsspec`` filesystem whose ``exists``/``glob``
            are handed to the Kedro base class. They are also merged into the
            Delta ``storage_options``.
        metadata: Arbitrary user metadata, stored on ``self.metadata``.
    """

    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: Dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
        self,
        filepath: str,
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        version: Version = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
        metadata: Dict[str, Any] = None,
    ) -> None:
        _fs_args = deepcopy(fs_args) or {}
        _credentials = deepcopy(credentials) or {}
        protocol, path = get_protocol_and_path(filepath, version)
        self._protocol = protocol

        # Build the Delta Lake storage options from user-supplied values only,
        # BEFORE the fsspec-specific ``auto_mkdir`` default is injected below.
        # delta-rs documents ``storage_options`` as a str -> str mapping, and a
        # boolean ``auto_mkdir`` means nothing to it.
        self._storage_options = {**_credentials, **_fs_args}

        if protocol == "file":
            _fs_args.setdefault("auto_mkdir", True)
        self._fs = fsspec.filesystem(self._protocol, **_fs_args)
        self.metadata = metadata

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        # Storage options are always passed explicitly from ``self._storage_options``.
        # A copy inside load/save args would clash with that keyword, so it is
        # dropped and the user is pointed at ``fs_args``/``credentials``.
        if "storage_options" in self._save_args or "storage_options" in self._load_args:
            logger.warning(
                "Dropping 'storage_options' for %s, "
                "please specify them under 'fs_args' or 'credentials'.",
                self._filepath,
            )
            self._save_args.pop("storage_options", None)
            self._load_args.pop("storage_options", None)

    def _describe(self) -> Dict[str, Any]:
        return {
            "filepath": self._filepath,
            "protocol": self._protocol,
            "load_args": self._load_args,
            "save_args": self._save_args,
            "version": self._version,
        }

    def _to_uri(self, path: PurePosixPath) -> str:
        """Return ``path`` as a ``<protocol>://<path>`` URI for delta-rs."""
        return f"{self._protocol}{PROTOCOL_DELIMITER}{path.as_posix()}"

    def _load(self) -> pl.DataFrame:
        """Read the table at the resolved load path.

        Returns:
            The table contents. If no Delta table exists at the path, an empty
            ``pl.DataFrame`` is returned.
        """
        load_path = self._to_uri(self._get_load_path())
        # HACK: If the table is empty, return an empty DataFrame
        try:
            return pl.read_delta(
                load_path, storage_options=self._storage_options, **self._load_args
            )
        except TableNotFoundError:
            return pl.DataFrame()

    def _save(self, data: pl.DataFrame) -> None:
        """Write ``data`` to the resolved save path with ``write_delta``.

        ``save_args`` are forwarded unchanged.
        """
        save_path = self._to_uri(self._get_save_path())
        data.write_delta(
            save_path, storage_options=self._storage_options, **self._save_args
        )

    def _exists(self) -> bool:
        """Return whether a Delta table exists at the resolved load path.

        Uses ``polars.scan_delta``, which opens the table's transaction log but
        defers reading data, so the check does not materialise the table.
        """
        try:
            # ``DatasetError`` is still caught: Kedro's versioned path resolution
            # in ``_get_load_path`` can raise its subclass (VersionNotFoundError).
            load_path = self._to_uri(self._get_load_path())
            pl.scan_delta(
                load_path,
                version=self._load_args.get("version"),
                storage_options=self._storage_options,
                delta_table_options=self._load_args.get("delta_table_options"),
            )
        except (TableNotFoundError, DatasetError):
            # A missing table surfaces as deltalake's TableNotFoundError, which
            # the previous ``except DatasetError`` alone never caught.
            return False
        return True