-
Notifications
You must be signed in to change notification settings - Fork 1
/
csv_table.py
81 lines (57 loc) · 2.54 KB
/
csv_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import typing
import pandas as pd
import pydantic
from retrack.nodes.base import InputConnectionModel, OutputConnectionModel
from retrack.nodes.dynamic.base import BaseDynamicIOModel, BaseDynamicNode
class CSVTableV0MetadataModel(pydantic.BaseModel):
    # Configuration payload of a CSV-table node.
    #
    # value:       table rows; either one delimited string per row or one list of
    #              cells per row. df() skips the first entry (presumably the header
    #              row -- confirm against the producer of this payload).
    # target:      name of the column that the node's run() returns.
    # headers:     declared alongside headers_map; not read by the code in this file.
    # headers_map: column names given to the DataFrame built by df(); every name
    #              other than `target` is treated as an input column by run().
    # separator:   delimiter handed to str.split for string rows.
    # default:     replacement for missing target values after the lookup.
    value: typing.Union[typing.List[str], typing.List[typing.List[str]]]
    target: str
    headers: typing.List[str]
    headers_map: typing.List[str]
    separator: typing.Optional[str] = ","
    default: typing.Optional[str] = None

    def df(self) -> pd.DataFrame:
        """Return the table rows as a DataFrame whose columns are ``headers_map``.

        The first entry of ``value`` is skipped. String rows are split on
        ``separator``; rows that are already lists of cells are used as they
        are, so both shapes permitted by the ``value`` annotation are accepted.
        """
        rows = [
            row.split(self.separator) if isinstance(row, str) else list(row)
            for row in self.value[1:]
        ]
        return pd.DataFrame(rows, columns=self.headers_map)
class CSVTableV0OutputsModel(pydantic.BaseModel):
    # Declares the node's single output connection. The field name matches the
    # "output_value" key of the dict returned by CSVTableV0.run().
    output_value: OutputConnectionModel
def csv_table_factory(
    inputs: typing.Dict[str, typing.Any], **kwargs
) -> typing.Type[BaseDynamicNode]:
    """Create a ``CSVTableV0`` node class whose input fields mirror ``inputs``.

    Args:
        inputs: Mapping whose keys name the node's input connections. Only the
            keys are used; the values are ignored.
        **kwargs: Accepted but not used by this factory.

    Returns:
        The dynamically assembled ``CSVTableV0`` node class.
    """
    # One InputConnectionModel sub-field per declared input name.
    input_fields = {
        name: BaseDynamicNode.create_sub_field(InputConnectionModel)
        for name in inputs
    }
    inputs_model = BaseDynamicIOModel.with_fields(
        "CSVTableV0InputsModel", **input_fields
    )

    models = {
        "inputs": BaseDynamicNode.create_sub_field(inputs_model),
        "outputs": BaseDynamicNode.create_sub_field(CSVTableV0OutputsModel),
        "data": BaseDynamicNode.create_sub_field(CSVTableV0MetadataModel),
    }

    BaseCSVTableV0Model = BaseDynamicNode.with_fields("CSVTableV0", **models)

    class CSVTableV0(BaseCSVTableV0Model):
        def run(self, **kwargs) -> typing.Dict[str, typing.Any]:
            """Look up the target column for each row of the supplied inputs.

            Args:
                **kwargs: Must contain one entry for every name in
                    ``data.headers_map`` other than ``data.target``. The value
                    for the first such name must expose ``.index`` (presumably
                    a pandas Series -- confirm against the caller).

            Returns:
                ``{"output_value": series}`` where ``series`` is the target
                column of the merged table, indexed like the first input.

            Raises:
                ValueError: If a required input is absent from ``kwargs``.
            """
            csv_df = self.data.df()

            # Every mapped column except the target is a lookup key.
            input_columns = [
                name for name in self.data.headers_map if name != self.data.target
            ]

            response_data = {}
            for name in input_columns:
                if name not in kwargs:
                    raise ValueError(f"Missing input {name} in CSVTableV0 node")
                response_data[name] = kwargs[name]

            # Inputs are cast to str before merging; this assumes the table
            # cells are strings, as they are when rows come from str.split.
            response_df = pd.DataFrame(response_data).astype(str)
            response_df = response_df.merge(csv_df, how="left", on=input_columns)

            # Truthiness check: an empty-string default behaves like no default.
            if self.data.default:
                response_df[self.data.target] = response_df[self.data.target].fillna(
                    self.data.default
                )

            # A column-on-column merge does not keep the left index, so the
            # index of the first input is put back on the result.
            response_df = response_df.set_index(
                kwargs[input_columns[0]].index, drop=True
            )

            return {"output_value": response_df[self.data.target]}

    return CSVTableV0