Skip to content

Commit

Permalink
feat: add bigframes.ml.compose.SQLScalarColumnTransformer to create c…
Browse files Browse the repository at this point in the history
…ustom SQL-based transformations (#955)

* Add support for custom transformers (not ML.) in ColumnTransformer.

* allow numbers in Custom-Transformer-IDs.

* comment was moved to the end of the sql.

* Do not offer the feedback link for missing custom transformers.

* cleanup typing hints.

* Add unit tests for CustomTransformer.

* added unit tests for _extract_output_names() and _compile_to_sql().

* run black and flake8 linter.

* fixed wrong @classmethod annotation.

* on the way to SQLScalarColumnTransformer

* remove pytest.main call.

* remove CustomTransformer class and implementations.

* fix typing.

* fix typing.

* fixed mock typing.

* replace _NameClass.

* black formatting.

* add target_column as input_column with a "?" prefix
when parsing SQLScalarColumnTransformer from sql.

* reformatted with black version 22.3.0.

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* remove eclipse project files

* SQLScalarColumnTransformer does not need to inherit from
base.BaseTransformer.

* remove filter for "ML." sqls in _extract_output_names() of
BaseTransformer

* introduced type hint SingleColTransformer
for transformers contained in ColumnTransformer

* make sql and target_column private in SQLScalarColumnTransformer

* Add documentation for SQLScalarColumnTransformer.

* add first system test for SQLScalarColumnTransformer.

* SQLScalarColumnTransformer system tests for fit-transform and save-load

* make SQLScalarColumnTransformer comparable (equals) for comparing sets
in tests

* implement hash and eq (copied from BaseTransformer)

* undo accidentally checked in files

* remove eclipse settings accidentally checked in.

* fix docs.

* Update bigframes/ml/compose.py

* Update bigframes/ml/compose.py

* add support for flexible column names.

* remove main.

* add system test for output column with flexible column name

* system tests: add new flexible output column to check-df-schema.

* Apply suggestions from code review

---------

Co-authored-by: Ferenc Hechler <[email protected]>
Co-authored-by: Tim Sweña (Swast) <[email protected]>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 25, 2024
1 parent 3c54399 commit 1930b4e
Show file tree
Hide file tree
Showing 4 changed files with 633 additions and 14 deletions.
4 changes: 0 additions & 4 deletions bigframes/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,6 @@ def _extract_output_names(self):
# pass the columns that are not transformed
if "transformSql" not in transform_col_dict:
continue
transform_sql: str = transform_col_dict["transformSql"]
if not transform_sql.startswith("ML."):
continue

output_names.append(transform_col_dict["name"])

self._output_names = output_names
Expand Down
134 changes: 124 additions & 10 deletions bigframes/ml/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,101 @@
)


class SQLScalarColumnTransformer:
    r"""
    Wrapper for plain SQL code contained in a ColumnTransformer.

    Create a single column transformer in plain sql.
    This transformer can only be used inside ColumnTransformer.

    When creating an instance '{0}' can be used as placeholder
    for the column to transform:

        SQLScalarColumnTransformer("{0}+1")

    The default target column gets the prefix 'transformed\_'
    but can also be changed when creating an instance:

        SQLScalarColumnTransformer("{0}+1", "inc_{0}")

    **Examples:**

        >>> from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer
        >>> import bigframes.pandas as bpd

        >>> df = bpd.DataFrame({'name': ["James", None, "Mary"], 'city': ["New York", "Boston", None]})
        >>> col_trans = ColumnTransformer([
        ...     ("strlen",
        ...      SQLScalarColumnTransformer("CASE WHEN {0} IS NULL THEN 15 ELSE LENGTH({0}) END"),
        ...      ['name', 'city']),
        ... ])
        >>> col_trans = col_trans.fit(df)
        >>> df_transformed = col_trans.transform(df)
        >>> df_transformed
           transformed_name  transformed_city
        0                 5                 8
        1                15                 6
        2                 4                15
        <BLANKLINE>
        [3 rows x 2 columns]

    SQLScalarColumnTransformer can be combined with other transformers,
    like StandardScaler:

        >>> col_trans = ColumnTransformer([
        ...     ("identity", SQLScalarColumnTransformer("{0}", target_column="{0}"), ["col1", "col5"]),
        ...     ("increment", SQLScalarColumnTransformer("{0}+1", target_column="inc_{0}"), "col2"),
        ...     ("stdscale", preprocessing.StandardScaler(), "col3"),
        ...     # ...
        ... ])
    """

    def __init__(self, sql: str, target_column: str = "transformed_{0}"):
        super().__init__()
        self._sql = sql
        # Backticks are stripped here; escape() re-adds them when the
        # (formatted) name actually needs quoting.
        self._target_column = target_column.replace("`", "")

    # Identifiers consisting of a letter followed by letters, digits or
    # underscores need no backtick quoting in the generated SQL.
    PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)

    def escape(self, colname: str):
        """Return *colname* backtick-quoted unless it is a plain identifier."""
        unquoted = colname.replace("`", "")
        if self.PLAIN_COLNAME_RX.match(unquoted):
            return unquoted
        return f"`{unquoted}`"

    def _compile_to_sql(
        self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
    ) -> List[str]:
        """Render one ``<expression> AS <target>`` clause per input column.

        When *columns* is None, all columns of *X* are transformed.
        """
        cols = X.columns if columns is None else columns
        return [
            f"{self._sql.format(self.escape(col))} AS "
            f"{self.escape(self._target_column.format(col))}"
            for col in cols
        ]

    def _keys(self):
        # Tuple of the attributes that define identity for __eq__/__hash__.
        return (self._sql, self._target_column)

    def __eq__(self, other) -> bool:
        return type(self) is type(other) and self._keys() == other._keys()

    def __hash__(self) -> int:
        return hash(self._keys())

    def __repr__(self):
        return (
            f"SQLScalarColumnTransformer(sql='{self._sql}', "
            f"target_column='{self._target_column}')"
        )


# Union of all single-column transformer types that may appear inside a
# ColumnTransformer: the standard preprocessing transformers, SimpleImputer,
# and plain-SQL SQLScalarColumnTransformer instances.
SingleColTransformer = Union[
    preprocessing.PreprocessingType,
    impute.SimpleImputer,
    SQLScalarColumnTransformer,
]


@log_adapter.class_logger
class ColumnTransformer(
base.Transformer,
Expand All @@ -60,7 +155,7 @@ def __init__(
transformers: Iterable[
Tuple[
str,
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
SingleColTransformer,
Union[str, Iterable[str]],
]
],
Expand All @@ -78,14 +173,12 @@ def _keys(self):
@property
def transformers_(
self,
) -> List[
Tuple[str, Union[preprocessing.PreprocessingType, impute.SimpleImputer], str]
]:
) -> List[Tuple[str, SingleColTransformer, str,]]:
"""The collection of transformers as tuples of (name, transformer, column)."""
result: List[
Tuple[
str,
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
SingleColTransformer,
str,
]
] = []
Expand All @@ -103,6 +196,8 @@ def transformers_(

return result

AS_FLEXNAME_SUFFIX_RX = re.compile("^(.*)\\bAS\\s*`[^`]+`\\s*$", re.IGNORECASE)

@classmethod
def _extract_from_bq_model(
cls,
Expand All @@ -114,7 +209,7 @@ def _extract_from_bq_model(
transformers_set: Set[
Tuple[
str,
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
SingleColTransformer,
Union[str, List[str]],
]
] = set()
Expand All @@ -130,8 +225,11 @@ def camel_to_snake(name):
if "transformSql" not in transform_col_dict:
continue
transform_sql: str = transform_col_dict["transformSql"]
if not transform_sql.startswith("ML."):
continue

# workaround for bug in bq_model returning " AS `...`" suffix for flexible names
flex_name_match = cls.AS_FLEXNAME_SUFFIX_RX.match(transform_sql)
if flex_name_match:
transform_sql = flex_name_match.group(1)

output_names.append(transform_col_dict["name"])
found_transformer = False
Expand All @@ -148,8 +246,22 @@ def camel_to_snake(name):
found_transformer = True
break
if not found_transformer:
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
if transform_sql.startswith("ML."):
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
)

target_column = transform_col_dict["name"]
sql_transformer = SQLScalarColumnTransformer(
transform_sql, target_column=target_column
)
input_column_name = f"?{target_column}"
transformers_set.add(
(
camel_to_snake(sql_transformer.__class__.__name__),
sql_transformer,
input_column_name,
)
)

transformer = cls(transformers=list(transformers_set))
Expand All @@ -167,6 +279,8 @@ def _merge(

assert len(transformers) > 0
_, transformer_0, column_0 = transformers[0]
if isinstance(transformer_0, SQLScalarColumnTransformer):
return self # SQLScalarColumnTransformer only work inside ColumnTransformer
feature_columns_sorted = sorted(
[
cast(str, feature_column.name)
Expand Down
103 changes: 103 additions & 0 deletions tests/system/large/ml/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@ def test_columntransformer_standalone_fit_and_transform(
preprocessing.MinMaxScaler(),
["culmen_length_mm"],
),
(
"increment",
compose.SQLScalarColumnTransformer("{0}+1"),
["culmen_length_mm", "flipper_length_mm"],
),
(
"length",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="len_{0}",
),
"species",
),
(
"ohe",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0}='Adelie Penguin (Pygoscelis adeliae)' THEN 1 ELSE 0 END",
target_column="ohe_adelie",
),
"species",
),
(
"identity",
compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
["culmen_length_mm", "flipper_length_mm"],
),
]
)

Expand All @@ -51,6 +77,12 @@ def test_columntransformer_standalone_fit_and_transform(
"standard_scaled_culmen_length_mm",
"min_max_scaled_culmen_length_mm",
"standard_scaled_flipper_length_mm",
"transformed_culmen_length_mm",
"transformed_flipper_length_mm",
"len_species",
"ohe_adelie",
"culmen_length_mm",
"flipper_length_mm",
],
index=[1633, 1672, 1690],
col_exact=False,
Expand All @@ -70,6 +102,19 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"length",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="len_{0}",
),
"species",
),
(
"identity",
compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
["culmen_length_mm", "flipper_length_mm"],
),
]
)

Expand All @@ -83,6 +128,9 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
"onehotencoded_species",
"standard_scaled_culmen_length_mm",
"standard_scaled_flipper_length_mm",
"len_species",
"culmen_length_mm",
"flipper_length_mm",
],
index=[1633, 1672, 1690],
col_exact=False,
Expand All @@ -102,6 +150,27 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"length",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="len_{0}",
),
"species",
),
(
"identity",
compose.SQLScalarColumnTransformer("{0}", target_column="{0}"),
["culmen_length_mm", "flipper_length_mm"],
),
(
"flexname",
compose.SQLScalarColumnTransformer(
"CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
target_column="Flex {0} Name",
),
"species",
),
]
)
transformer.fit(
Expand All @@ -122,6 +191,36 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
),
("standard_scaler", preprocessing.StandardScaler(), "culmen_length_mm"),
("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END",
target_column="len_species",
),
"?len_species",
),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"flipper_length_mm", target_column="flipper_length_mm"
),
"?flipper_length_mm",
),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"culmen_length_mm", target_column="culmen_length_mm"
),
"?culmen_length_mm",
),
(
"sql_scalar_column_transformer",
compose.SQLScalarColumnTransformer(
"CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END ",
target_column="Flex species Name",
),
"?Flex species Name",
),
]
assert set(reloaded_transformer.transformers) == set(expected)
assert reloaded_transformer._bqml_model is not None
Expand All @@ -136,6 +235,10 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
"onehotencoded_species",
"standard_scaled_culmen_length_mm",
"standard_scaled_flipper_length_mm",
"len_species",
"culmen_length_mm",
"flipper_length_mm",
"Flex species Name",
],
index=[1633, 1672, 1690],
col_exact=False,
Expand Down
Loading

0 comments on commit 1930b4e

Please sign in to comment.