Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConstInspector and ConstValueTransformer for Handling Constant Columns #202

Merged
merged 32 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
bfe5c74
add value_fields in metadata
MooooCat Jul 12, 2024
f8cf69a
add ConstInspector
MooooCat Jul 12, 2024
159b708
add ConstValueTransformer and its testcase
MooooCat Jul 12, 2024
6f6d87a
update comments and validator parameter name
MooooCat Jul 12, 2024
24dfaab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2024
b085b2a
Merge branch 'main' into feature-handle-const-columns
MooooCat Jul 15, 2024
67c5a6e
update typo in test case
MooooCat Jul 15, 2024
01f156e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
085dfb2
fix metadata test error of value fields
MooooCat Jul 15, 2024
f0472e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
29b7bfa
add const_columns in default
MooooCat Jul 15, 2024
c23f632
Refreshing the test cases
MooooCat Jul 15, 2024
d0c5733
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
029d5ba
Revise the unit tests for the module handling two constant collection…
MooooCat Jul 15, 2024
63d0b95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
9324b45
add deepcopy in const.py
MooooCat Jul 15, 2024
fb1d3b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
2fa14eb
restore metadata, modify const inspector and transformer
MooooCat Jul 18, 2024
7d0d5c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
327bee1
remove const_values in inspector's unit test
MooooCat Jul 18, 2024
bd23517
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
dbad60f
remove multiple metadata in unit test
MooooCat Jul 18, 2024
9357e5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
23dd461
set all inspectors not ready before fit chunk
MooooCat Jul 18, 2024
18d6797
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
966641b
try reset the const columns after inspect
MooooCat Jul 19, 2024
accf17f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
ed13d57
Merge branch 'main' into feature-handle-const-columns
MooooCat Jul 31, 2024
afa0dbd
clear column set before inspector fit
MooooCat Jul 31, 2024
f32311e
change test func name
MooooCat Jul 31, 2024
14097f4
add const type in test case
MooooCat Jul 31, 2024
fafa39f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdgx/data_models/inspectors/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
Args:
raw_data (pd.DataFrame): Raw data
"""
self.bool_columns = set()
self.bool_columns = self.bool_columns.union(
set(raw_data.infer_objects().select_dtypes(include=["bool"]).columns)
)
Expand Down
70 changes: 70 additions & 0 deletions sdgx/data_models/inspectors/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

import copy
from typing import Any

import pandas as pd

from sdgx.data_models.inspectors.base import Inspector
from sdgx.data_models.inspectors.extension import hookimpl


class ConstInspector(Inspector):
"""
ConstInspector is a class designed to identify columns in a DataFrame that contain constant values.
It extends the base Inspector class and is used to fit the data and inspect it for constant columns.

Attributes:
const_columns (set[str]): A set of column names that contain constant values.
const_values (dict[Any]): A dictionary mapping column names to their constant values.
_inspect_level (int): The inspection level for this inspector, set to 80.
"""

const_columns: set[str] = set()
"""
A set of column names that contain constant values. This attribute is populated during the fit method by identifying columns in the DataFrame where all values are the same.
"""

const_values: dict[Any] = {}
"""
A dictionary mapping column names to their constant values. This attribute is populated during the fit method by storing the unique value found in each constant column.
"""

_inspect_level = 80
"""
The inspection level for this inspector, set to 80. This attribute indicates the priority or depth of inspection that this inspector performs relative to other inspectors.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""
Fit the inspector to the raw data.

This method identifies columns in the DataFrame that contain constant values. It populates the `const_columns` set with the names of these columns and the `const_values` dictionary with the constant values found in each column.

Args:
raw_data (pd.DataFrame): The raw data to be inspected.

Returns:
None
"""
self.const_columns = set()
# iterate each column
for column in raw_data.columns:
if len(raw_data[column].value_counts(normalize=True)) == 1:
self.const_columns.add(column)
# self.const_values[column] = raw_data[column][0]

self.ready = True

def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"const_columns": self.const_columns}


@hookimpl
def register(manager):
manager.register("ConstInspector", ConstInspector)
2 changes: 2 additions & 0 deletions sdgx/data_models/inspectors/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
Args:
raw_data (pd.DataFrame): Raw data
"""
self.datetime_columns = set()

self.datetime_columns = self.datetime_columns.union(
set(raw_data.infer_objects().select_dtypes(include=["datetime64"]).columns)
)
Expand Down
1 change: 1 addition & 0 deletions sdgx/data_models/inspectors/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
Args:
raw_data (pd.DataFrame): Raw data
"""
self.discrete_columns = set()

self.discrete_columns = self.discrete_columns.union(
set(raw_data.select_dtypes(include="object").columns)
Expand Down
2 changes: 2 additions & 0 deletions sdgx/data_models/inspectors/i_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
raw_data (pd.DataFrame): Raw data
"""

self.ID_columns = set()

df_length = len(raw_data)
candidate_columns = set(raw_data.select_dtypes(include=["object", "int64"]).columns)

Expand Down
3 changes: 3 additions & 0 deletions sdgx/data_models/inspectors/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
raw_data (pd.DataFrame): Raw data
"""

self.int_columns = set()
self.float_columns = set()

self.df_length = len(raw_data)

float_candidate = self.float_columns.union(
Expand Down
4 changes: 4 additions & 0 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def check_column_list(cls, value) -> Any:
bool_columns: Set[str] = set()
discrete_columns: Set[str] = set()
datetime_columns: Set[str] = set()
const_columns: Set[str] = set()
datetime_format: Dict = defaultdict(str)

# version info
Expand Down Expand Up @@ -298,6 +299,9 @@ def from_dataloader(
inspectors = im.init_inspcetors(
include_inspectors, exclude_inspectors, **(inspector_init_kwargs or {})
)
# set all inspectors not ready
for inspector in inspectors:
inspector.ready = False
for i, chunk in enumerate(dataloader.iter()):
for inspector in inspectors:
if not inspector.ready:
Expand Down
6 changes: 5 additions & 1 deletion sdgx/data_processors/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ class DataProcessorManager(Manager):
"IntValueFormatter",
"DatetimeFormatter",
]
] + ["EmptyTransformer".lower(), "ColumnOrderTransformer".lower()]
] + [
"ConstValueTransformer".lower(),
"EmptyTransformer".lower(),
"ColumnOrderTransformer".lower(),
]
"""
preset_defalut_processors list stores the lowercase names of the transformers loaded by default. When using the synthesizer, they will be loaded by default to facilitate user operations.

Expand Down
110 changes: 110 additions & 0 deletions sdgx/data_processors/transformers/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

import copy
from typing import Any

import pandas as pd

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.extension import hookimpl
from sdgx.data_processors.transformers.base import Transformer
from sdgx.utils import logger


class ConstValueTransformer(Transformer):
"""
A transformer that replaces the input with a constant value.

This class is used to transform any input data into a predefined constant value.
It is particularly useful in scenarios where a consistent output is required regardless of the input.

Attributes:
const_value (dict[Any]): The constant value that will be returned.
"""

const_columns: list = []

const_values: dict[Any] = {}

def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
"""
Fit method for the transformer.

This method processes the metadata to identify columns that should be replaced with a constant value.
It updates the internal state of the transformer with the columns and their corresponding constant values.

Args:
metadata (Metadata | None): The metadata object containing information about the columns and their data types.
**kwargs (dict[str, Any]): Additional keyword arguments.

Returns:
None
"""

for each_col in metadata.column_list:
if metadata.get_column_data_type(each_col) == "const":
self.const_columns.append(each_col)
# self.const_values[each_col] = metadata.get("const_values")[each_col]

logger.info("ConstValueTransformer Fitted.")

self.fitted = True

def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
"""
Convert method to handle missing values in the input data by replacing specified columns with constant values.

This method iterates over the columns identified for replacement with constant values and removes them from the input DataFrame.
The removal is based on the columns specified during the fitting process.

Args:
raw_data (pd.DataFrame): The input DataFrame containing the data to be processed.

Returns:
pd.DataFrame: A DataFrame with the specified columns removed.
"""

processed_data = copy.deepcopy(raw_data)

logger.info("Converting data using ConstValueTransformer...")

for each_col in self.const_columns:
# record values here
if each_col not in self.const_values.keys():
self.const_values[each_col] = processed_data[each_col].unique()[0]
processed_data = self.remove_columns(processed_data, [each_col])

logger.info("Converting data using ConstValueTransformer... Finished.")

return processed_data

def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
"""
Reverse_convert method for the transformer.

This method restores the original columns that were replaced with constant values during the conversion process.
It iterates over the columns identified for replacement with constant values and adds them back to the DataFrame
with the predefined constant values.

Args:
processed_data (pd.DataFrame): The input DataFrame containing the processed data.

Returns:
pd.DataFrame: A DataFrame with the original columns restored, filled with their corresponding constant values.
"""
df_length = processed_data.shape[0]

for each_col_name in self.const_columns:
each_value = self.const_values[each_col_name]
each_const_col = [each_value for _ in range(df_length)]
each_const_df = pd.DataFrame({each_col_name: each_const_col})
processed_data = self.attach_columns(processed_data, each_const_df)

logger.info("Data reverse-converted by ConstValueTransformer.")

return processed_data


@hookimpl
def register(manager):
manager.register("ConstValueTransformer", ConstValueTransformer)
36 changes: 36 additions & 0 deletions tests/data_models/inspector/test_const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import copy

import pandas as pd
import pytest

from sdgx.data_models.inspectors.const import ConstInspector


@pytest.fixture
def test_const_data(demo_single_table_path):
const_col_df = pd.read_csv(demo_single_table_path)

# Convert the columns to float to allow None values
const_col_df["age"] = const_col_df["age"].astype(float)
const_col_df["fnlwgt"] = const_col_df["fnlwgt"].astype(float)

# Set the values to None
const_col_df["age"].values[:] = 100
const_col_df["fnlwgt"].values[:] = 3.14
const_col_df["workclass"].values[:] = "President"

yield const_col_df


def test_const_inspector(test_const_data: pd.DataFrame):
inspector = ConstInspector()
inspector.fit(test_const_data)
assert inspector.ready
assert inspector.const_columns

assert sorted(inspector.inspect()["const_columns"]) == sorted(["age", "fnlwgt", "workclass"])
assert inspector.inspect_level == 80


if __name__ == "__main__":
pytest.main(["-vv", "-s", __file__])
2 changes: 1 addition & 1 deletion tests/data_models/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_demo_multi_table_data_metadata_child(demo_multi_data_child_matadata):
assert demo_multi_data_child_matadata.get_column_data_type("Store") == "int"
assert demo_multi_data_child_matadata.get_column_data_type("Date") == "datetime"
assert demo_multi_data_child_matadata.get_column_data_type("Customers") == "int"
assert demo_multi_data_child_matadata.get_column_data_type("StateHoliday") == "int"
assert demo_multi_data_child_matadata.get_column_data_type("StateHoliday") == "const"
assert demo_multi_data_child_matadata.get_column_data_type("Sales") == "int"
assert demo_multi_data_child_matadata.get_column_data_type("Promo") == "int"
assert demo_multi_data_child_matadata.get_column_data_type("DayOfWeek") == "int"
Expand Down
83 changes: 83 additions & 0 deletions tests/data_processors/transformers/test_transformers_const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import copy

import numpy as np
import pandas as pd
import pytest

from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.transformers.const import ConstValueTransformer


@pytest.fixture
def test_const_data(demo_single_table_path):

const_col_df = pd.read_csv(demo_single_table_path)
# Convert the columns to float to allow None values
const_col_df["age"] = const_col_df["age"].astype(float)
const_col_df["fnlwgt"] = const_col_df["fnlwgt"].astype(float)

# Set the values to None
const_col_df["age"].values[:] = 100
const_col_df["fnlwgt"].values[:] = 1.41421
const_col_df["workclass"].values[:] = "President"

yield const_col_df


def test_const_handling_test_df(test_const_data: pd.DataFrame):
"""
Test the handling of const columns in a DataFrame.
This function tests the behavior of a DataFrame when it contains const columns.
It is designed to be used in a testing environment, where the DataFrame is passed as an argument.

Parameters:
test_const_data (pd.DataFrame): The DataFrame to test.

Returns:
None

Raises:
AssertionError: If the DataFrame does not handle const columns as expected.
"""

metadata = Metadata.from_dataframe(test_const_data)

# Initialize the ConstValueTransformer.
const_transformer = ConstValueTransformer()
# Check if the transformer has not been fitted yet.
assert const_transformer.fitted is False

# Fit the transformer with the DataFrame.
const_transformer.fit(metadata)

# Check if the transformer has been fitted after the fit operation.
assert const_transformer.fitted

# Check the const column
assert sorted(const_transformer.const_columns) == [
"age",
"fnlwgt",
"workclass",
]

# Transform the DataFrame using the transformer.
transformed_df = const_transformer.convert(test_const_data)

assert "age" not in transformed_df.columns
assert "fnlwgt" not in transformed_df.columns
assert "workclass" not in transformed_df.columns

# reverse convert the df
reverse_converted_df = const_transformer.reverse_convert(transformed_df)

assert "age" in reverse_converted_df.columns
assert "fnlwgt" in reverse_converted_df.columns
assert "workclass" in reverse_converted_df.columns

assert reverse_converted_df["age"][0] == 100
assert reverse_converted_df["fnlwgt"][0] == 1.41421
assert reverse_converted_df["workclass"][0] == "President"

assert len(reverse_converted_df["age"].unique()) == 1
assert len(reverse_converted_df["fnlwgt"].unique()) == 1
assert len(reverse_converted_df["workclass"].unique()) == 1
Loading