diff --git a/sdgx/data_models/inspectors/bool.py b/sdgx/data_models/inspectors/bool.py
index 53bc7b12..acf0e6a0 100644
--- a/sdgx/data_models/inspectors/bool.py
+++ b/sdgx/data_models/inspectors/bool.py
@@ -22,6 +22,7 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
         Args:
             raw_data (pd.DataFrame): Raw data
         """
+        self.bool_columns = set()
         self.bool_columns = self.bool_columns.union(
             set(raw_data.infer_objects().select_dtypes(include=["bool"]).columns)
         )
diff --git a/sdgx/data_models/inspectors/const.py b/sdgx/data_models/inspectors/const.py
new file mode 100644
index 00000000..8cfc75f7
--- /dev/null
+++ b/sdgx/data_models/inspectors/const.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import copy
+from typing import Any
+
+import pandas as pd
+
+from sdgx.data_models.inspectors.base import Inspector
+from sdgx.data_models.inspectors.extension import hookimpl
+
+
+class ConstInspector(Inspector):
+    """
+    ConstInspector is a class designed to identify columns in a DataFrame that contain constant values.
+    It extends the base Inspector class and is used to fit the data and inspect it for constant columns.
+
+    Attributes:
+        const_columns (set[str]): A set of column names that contain constant values.
+        const_values (dict[str, Any]): A dictionary mapping column names to their constant values (not populated by `fit` at present).
+        _inspect_level (int): The inspection level for this inspector, set to 80.
+    """
+
+    const_columns: set[str] = set()
+    """
+    A set of column names that contain constant values. This attribute is populated during the fit method by identifying columns in the DataFrame where all values are the same.
+    """
+
+    const_values: dict[str, Any] = {}
+    """
+    A dictionary mapping column names to their constant values. It is not populated by the fit method at present: the assignment there is still commented out.
+    """
+
+    _inspect_level = 80
+    """
+    The inspection level for this inspector, set to 80. This attribute indicates the priority or depth of inspection that this inspector performs relative to other inspectors.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
+        """
+        Fit the inspector to the raw data.
+
+        This method identifies columns in the DataFrame that contain constant values and stores their names in the `const_columns` set. Missing values are ignored when counting the distinct values of a column.
+
+        Args:
+            raw_data (pd.DataFrame): The raw data to be inspected.
+
+        Returns:
+            None
+        """
+        self.const_columns = set()
+        # iterate each column
+        for column in raw_data.columns:
+            # nunique() skips missing values, just like the default value_counts()
+            if raw_data[column].nunique() == 1:
+                self.const_columns.add(column)
+                # self.const_values[column] = raw_data[column][0]
+
+        self.ready = True
+
+    def inspect(self, *args, **kwargs) -> dict[str, Any]:
+        """Inspect raw data and generate metadata."""
+
+        return {"const_columns": self.const_columns}
+
+
+@hookimpl
+def register(manager):
+    manager.register("ConstInspector", ConstInspector)
diff --git a/sdgx/data_models/inspectors/datetime.py b/sdgx/data_models/inspectors/datetime.py
index 84c8ee03..cecc5bb5 100644
--- a/sdgx/data_models/inspectors/datetime.py
+++ b/sdgx/data_models/inspectors/datetime.py
@@ -62,6 +62,8 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
         Args:
             raw_data (pd.DataFrame): Raw data
         """
+        self.datetime_columns = set()
+
         self.datetime_columns = self.datetime_columns.union(
             set(raw_data.infer_objects().select_dtypes(include=["datetime64"]).columns)
         )
diff --git a/sdgx/data_models/inspectors/discrete.py b/sdgx/data_models/inspectors/discrete.py
index 8f7e229c..d38923d0 100644
--- a/sdgx/data_models/inspectors/discrete.py
+++ b/sdgx/data_models/inspectors/discrete.py
@@ -21,6 +21,7 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
         Args:
             raw_data (pd.DataFrame): Raw data
         """
+        self.discrete_columns = set()
 
         self.discrete_columns = self.discrete_columns.union(
             set(raw_data.select_dtypes(include="object").columns)
diff --git a/sdgx/data_models/inspectors/i_id.py b/sdgx/data_models/inspectors/i_id.py
index 7ee91065..b21c8daa 100644
--- a/sdgx/data_models/inspectors/i_id.py
+++ b/sdgx/data_models/inspectors/i_id.py
@@ -29,6 +29,8 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
             raw_data (pd.DataFrame): Raw data
         """
 
+        self.ID_columns = set()
+
         df_length = len(raw_data)
         candidate_columns = set(raw_data.select_dtypes(include=["object", "int64"]).columns)
 
diff --git a/sdgx/data_models/inspectors/numeric.py b/sdgx/data_models/inspectors/numeric.py
index 2147589d..d71661d4 100644
--- a/sdgx/data_models/inspectors/numeric.py
+++ b/sdgx/data_models/inspectors/numeric.py
@@ -87,6 +87,9 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
             raw_data (pd.DataFrame): Raw data
         """
 
+        self.int_columns = set()
+        self.float_columns = set()
+
         self.df_length = len(raw_data)
 
         float_candidate = self.float_columns.union(
diff --git a/sdgx/data_models/metadata.py b/sdgx/data_models/metadata.py
index f1df8d9f..a3ca264b 100644
--- a/sdgx/data_models/metadata.py
+++ b/sdgx/data_models/metadata.py
@@ -71,6 +71,7 @@ def check_column_list(cls, value) -> Any:
     bool_columns: Set[str] = set()
     discrete_columns: Set[str] = set()
     datetime_columns: Set[str] = set()
+    const_columns: Set[str] = set()
     datetime_format: Dict = defaultdict(str)
 
     # version info
@@ -298,6 +299,9 @@ def from_dataloader(
         inspectors = im.init_inspcetors(
             include_inspectors, exclude_inspectors, **(inspector_init_kwargs or {})
         )
+        # set all inspectors not ready
+        for inspector in inspectors:
+            inspector.ready = False
         for i, chunk in enumerate(dataloader.iter()):
             for inspector in inspectors:
                 if not inspector.ready:
diff --git a/sdgx/data_processors/manager.py b/sdgx/data_processors/manager.py
index 226df258..d995babe 100644
--- a/sdgx/data_processors/manager.py
+++ b/sdgx/data_processors/manager.py
@@ -53,7 +53,11 @@ class DataProcessorManager(Manager):
             "IntValueFormatter",
             "DatetimeFormatter",
         ]
-    ] + ["EmptyTransformer".lower(), "ColumnOrderTransformer".lower()]
+    ] + [
+        "ConstValueTransformer".lower(),
+        "EmptyTransformer".lower(),
+        "ColumnOrderTransformer".lower(),
+    ]
     """
     preset_defalut_processors list stores the lowercase names of the transformers loaded by default. When using the synthesizer, they will be loaded by default to facilitate user operations.
 
diff --git a/sdgx/data_processors/transformers/const.py b/sdgx/data_processors/transformers/const.py
new file mode 100644
index 00000000..4af042d8
--- /dev/null
+++ b/sdgx/data_processors/transformers/const.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import copy
+from typing import Any
+
+import pandas as pd
+
+from sdgx.data_models.metadata import Metadata
+from sdgx.data_processors.extension import hookimpl
+from sdgx.data_processors.transformers.base import Transformer
+from sdgx.utils import logger
+
+
+class ConstValueTransformer(Transformer):
+    """
+    A transformer that removes constant columns and restores them afterwards.
+
+    Columns whose metadata data type is "const" are removed by `convert`, which also records their single value.
+    `reverse_convert` attaches those columns again, filled with the recorded values.
+
+    Attributes:
+        const_columns (list): Names of the columns treated as constant.
+        const_values (dict[str, Any]): Mapping from column name to its recorded constant value.
+    """
+
+    const_columns: list = []
+
+    const_values: dict[str, Any] = {}
+
+    def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
+        """
+        Fit method for the transformer.
+
+        This method reads the metadata and collects the columns whose data type is "const".
+        The constant values themselves are recorded later, by `convert`.
+
+        Args:
+            metadata (Metadata | None): The metadata object containing information about the columns and their data types.
+            **kwargs (dict[str, Any]): Additional keyword arguments.
+
+        Returns:
+            None
+        """
+
+        # Rebind both containers on the instance: the class-level defaults are shared
+        # by every instance, so mutating them in place would leak state across transformers.
+        self.const_columns = [
+            each_col
+            for each_col in metadata.column_list
+            if metadata.get_column_data_type(each_col) == "const"
+        ]
+        self.const_values = {}
+
+        logger.info("ConstValueTransformer Fitted.")
+
+        self.fitted = True
+
+    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Convert method that removes the constant columns from the input data.
+
+        For every column collected during fitting, the constant value is recorded (unless one is already recorded)
+        and the column is then removed from a copy of the input DataFrame.
+
+        Args:
+            raw_data (pd.DataFrame): The input DataFrame containing the data to be processed.
+
+        Returns:
+            pd.DataFrame: A DataFrame with the specified columns removed.
+        """
+
+        processed_data = copy.deepcopy(raw_data)
+
+        logger.info("Converting data using ConstValueTransformer...")
+
+        for each_col in self.const_columns:
+            # record values here
+            if each_col not in self.const_values:
+                self.const_values[each_col] = processed_data[each_col].unique()[0]
+            processed_data = self.remove_columns(processed_data, [each_col])
+
+        logger.info("Converting data using ConstValueTransformer... Finished.")
+
+        return processed_data
+
+    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
+        """
+        Reverse_convert method for the transformer.
+
+        This method restores the columns that were removed during the conversion process.
+        It iterates over the constant columns and adds them back to the DataFrame,
+        filled with the values recorded by `convert`.
+
+        Args:
+            processed_data (pd.DataFrame): The input DataFrame containing the processed data.
+
+        Returns:
+            pd.DataFrame: A DataFrame with the original columns restored, filled with their corresponding constant values.
+        """
+        df_length = processed_data.shape[0]
+
+        for each_col_name in self.const_columns:
+            each_value = self.const_values[each_col_name]
+            each_const_col = [each_value for _ in range(df_length)]
+            # Reuse the incoming index so the restored column lines up row by row.
+            each_const_df = pd.DataFrame({each_col_name: each_const_col}, index=processed_data.index)
+            processed_data = self.attach_columns(processed_data, each_const_df)
+
+        logger.info("Data reverse-converted by ConstValueTransformer.")
+
+        return processed_data
+
+
+@hookimpl
+def register(manager):
+    manager.register("ConstValueTransformer", ConstValueTransformer)
diff --git a/tests/data_models/inspector/test_const.py b/tests/data_models/inspector/test_const.py
new file mode 100644
index 00000000..49fe0f62
--- /dev/null
+++ b/tests/data_models/inspector/test_const.py
@@ -0,0 +1,34 @@
+import copy
+
+import pandas as pd
+import pytest
+
+from sdgx.data_models.inspectors.const import ConstInspector
+
+
+@pytest.fixture
+def test_const_data(demo_single_table_path):
+    const_col_df = pd.read_csv(demo_single_table_path)
+
+    # Overwrite whole columns with a single value to make them constant.
+    # Plain column assignment is used instead of writing through `.values`,
+    # which is not guaranteed to be a writable view of the column.
+    const_col_df["age"] = 100.0
+    const_col_df["fnlwgt"] = 3.14
+    const_col_df["workclass"] = "President"
+
+    yield const_col_df
+
+
+def test_const_inspector(test_const_data: pd.DataFrame):
+    inspector = ConstInspector()
+    inspector.fit(test_const_data)
+    assert inspector.ready
+    assert inspector.const_columns
+
+    assert sorted(inspector.inspect()["const_columns"]) == sorted(["age", "fnlwgt", "workclass"])
+    assert inspector.inspect_level == 80
+
+
+if __name__ == "__main__":
+    pytest.main(["-vv", "-s", __file__])
diff --git a/tests/data_models/test_metadata.py b/tests/data_models/test_metadata.py
index 379f75e8..db8421b0 100644
--- a/tests/data_models/test_metadata.py
+++ b/tests/data_models/test_metadata.py
@@ -90,7 +90,7 @@ def test_demo_multi_table_data_metadata_child(demo_multi_data_child_matadata):
     assert demo_multi_data_child_matadata.get_column_data_type("Store") == "int"
     assert demo_multi_data_child_matadata.get_column_data_type("Date") == "datetime"
     assert demo_multi_data_child_matadata.get_column_data_type("Customers") == "int"
-    assert demo_multi_data_child_matadata.get_column_data_type("StateHoliday") == "int"
+    assert demo_multi_data_child_matadata.get_column_data_type("StateHoliday") == "const"
     assert demo_multi_data_child_matadata.get_column_data_type("Sales") == "int"
     assert demo_multi_data_child_matadata.get_column_data_type("Promo") == "int"
     assert demo_multi_data_child_matadata.get_column_data_type("DayOfWeek") == "int"
diff --git a/tests/data_processors/transformers/test_transformers_const.py b/tests/data_processors/transformers/test_transformers_const.py
new file mode 100644
index 00000000..585b4056
--- /dev/null
+++ b/tests/data_processors/transformers/test_transformers_const.py
@@ -0,0 +1,84 @@
+import copy
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from sdgx.data_models.metadata import Metadata
+from sdgx.data_processors.transformers.const import ConstValueTransformer
+
+
+@pytest.fixture
+def test_const_data(demo_single_table_path):
+    const_col_df = pd.read_csv(demo_single_table_path)
+
+    # Overwrite whole columns with a single value to make them constant.
+    # Plain column assignment is used instead of writing through `.values`,
+    # which is not guaranteed to be a writable view of the column.
+    const_col_df["age"] = 100.0
+    const_col_df["fnlwgt"] = 1.41421
+    const_col_df["workclass"] = "President"
+
+    yield const_col_df
+
+
+def test_const_handling_test_df(test_const_data: pd.DataFrame):
+    """
+    Test the handling of const columns in a DataFrame.
+    This function tests the behavior of a DataFrame when it contains const columns.
+    It is designed to be used in a testing environment, where the DataFrame is passed as an argument.
+
+    Parameters:
+        test_const_data (pd.DataFrame): The DataFrame to test.
+
+    Returns:
+        None
+
+    Raises:
+        AssertionError: If the DataFrame does not handle const columns as expected.
+    """
+
+    metadata = Metadata.from_dataframe(test_const_data)
+
+    # Initialize the ConstValueTransformer.
+    const_transformer = ConstValueTransformer()
+    # Check if the transformer has not been fitted yet.
+    assert const_transformer.fitted is False
+
+    # Fit the transformer with the DataFrame.
+    const_transformer.fit(metadata)
+
+    # Check if the transformer has been fitted after the fit operation.
+    assert const_transformer.fitted
+
+    # Check the const column
+    assert sorted(const_transformer.const_columns) == [
+        "age",
+        "fnlwgt",
+        "workclass",
+    ]
+
+    # A fresh instance must not inherit the columns collected by the fitted one.
+    assert ConstValueTransformer().const_columns == []
+
+    # Transform the DataFrame using the transformer.
+    transformed_df = const_transformer.convert(test_const_data)
+
+    assert "age" not in transformed_df.columns
+    assert "fnlwgt" not in transformed_df.columns
+    assert "workclass" not in transformed_df.columns
+
+    # reverse convert the df
+    reverse_converted_df = const_transformer.reverse_convert(transformed_df)
+
+    assert "age" in reverse_converted_df.columns
+    assert "fnlwgt" in reverse_converted_df.columns
+    assert "workclass" in reverse_converted_df.columns
+
+    assert reverse_converted_df["age"][0] == 100
+    assert reverse_converted_df["fnlwgt"][0] == 1.41421
+    assert reverse_converted_df["workclass"][0] == "President"
+
+    assert len(reverse_converted_df["age"].unique()) == 1
+    assert len(reverse_converted_df["fnlwgt"].unique()) == 1
+    assert len(reverse_converted_df["workclass"].unique()) == 1