Skip to content

Commit

Permalink
fix: Escape ids more consistently in ml module (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorBergeron authored Oct 11, 2024
1 parent 8d74269 commit 103e998
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 142 deletions.
4 changes: 2 additions & 2 deletions bigframes/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def label_to_identifier(label: typing.Hashable, strict: bool = False) -> str:
"""
# Column values will be loaded as null if the column name has spaces.
# https://github.com/googleapis/python-bigquery/issues/1566
identifier = str(label).replace(" ", "_")

identifier = str(label)
if strict:
identifier = str(label).replace(" ", "_")
identifier = re.sub(r"[^a-zA-Z0-9_]", "", identifier)
if not identifier:
identifier = "id"
Expand Down
17 changes: 8 additions & 9 deletions bigframes/ml/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.cloud import bigquery

from bigframes.core import log_adapter
import bigframes.core.compile.googlesql as sql_utils
from bigframes.ml import base, core, globals, impute, preprocessing, utils
import bigframes.pandas as bpd

Expand Down Expand Up @@ -98,25 +99,22 @@ class SQLScalarColumnTransformer:
def __init__(self, sql: str, target_column: str = "transformed_{0}"):
super().__init__()
self._sql = sql
# TODO: More robust unescaping
self._target_column = target_column.replace("`", "")

PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)

def escape(self, colname: str):
colname = colname.replace("`", "")
if self.PLAIN_COLNAME_RX.match(colname):
return colname
return f"`{colname}`"

def _compile_to_sql(
self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
) -> List[str]:
if columns is None:
columns = X.columns
result = []
for column in columns:
current_sql = self._sql.format(self.escape(column))
current_target_column = self.escape(self._target_column.format(column))
current_sql = self._sql.format(sql_utils.identifier(column))
current_target_column = sql_utils.identifier(
self._target_column.format(column)
)
result.append(f"{current_sql} AS {current_target_column}")
return result

Expand Down Expand Up @@ -239,6 +237,7 @@ def camel_to_snake(name):
transformers_set.add(
(
camel_to_snake(transformer_cls.__name__),
# TODO: This is very fragile, use real SQL parser
*transformer_cls._parse_from_sql(transform_sql), # type: ignore
)
)
Expand All @@ -253,7 +252,7 @@ def camel_to_snake(name):

target_column = transform_col_dict["name"]
sql_transformer = SQLScalarColumnTransformer(
transform_sql, target_column=target_column
transform_sql.strip(), target_column=target_column
)
input_column_name = f"?{target_column}"
transformers_set.add(
Expand Down
4 changes: 3 additions & 1 deletion bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ class BqmlModel(BaseBqml):
def __init__(self, session: bigframes.Session, model: bigquery.Model):
self._session = session
self._model = model
model_ref = self._model.reference
assert model_ref is not None
self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
self.model_name
model_ref
)

def _apply_ml_tvf(
Expand Down
10 changes: 9 additions & 1 deletion bigframes/ml/impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
tuple(SimpleImputer, column_label)"""
s = sql[sql.find("(") + 1 : sql.find(")")]
col_label, strategy = s.split(", ")
return cls(strategy[1:-1]), col_label # type: ignore[arg-type]
return cls(strategy[1:-1]), _unescape_id(col_label) # type: ignore[arg-type]

def fit(
self,
Expand Down Expand Up @@ -110,3 +110,11 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
bpd.DataFrame,
df[self._output_names],
)


def _unescape_id(id: str) -> str:
"""Very simple conversion to removed ` characters from ids.
A proper sql parser should be used instead.
"""
return id.removeprefix("`").removesuffix("`")
26 changes: 18 additions & 8 deletions bigframes/ml/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[StandardScaler, str]:
Returns:
tuple(StandardScaler, column_label)"""
col_label = sql[sql.find("(") + 1 : sql.find(")")]
return cls(), col_label
return cls(), _unescape_id(col_label)

def fit(
self,
Expand Down Expand Up @@ -152,8 +152,9 @@ def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]:
Returns:
tuple(MaxAbsScaler, column_label)"""
# TODO: Use real sql parser
col_label = sql[sql.find("(") + 1 : sql.find(")")]
return cls(), col_label
return cls(), _unescape_id(col_label)

def fit(
self,
Expand Down Expand Up @@ -229,8 +230,9 @@ def _parse_from_sql(cls, sql: str) -> tuple[MinMaxScaler, str]:
Returns:
tuple(MinMaxScaler, column_label)"""
# TODO: Use real sql parser
col_label = sql[sql.find("(") + 1 : sql.find(")")]
return cls(), col_label
return cls(), _unescape_id(col_label)

def fit(
self,
Expand Down Expand Up @@ -349,11 +351,11 @@ def _parse_from_sql(cls, sql: str) -> tuple[KBinsDiscretizer, str]:

if sql.startswith("ML.QUANTILE_BUCKETIZE"):
num_bins = s.split(",")[1]
return cls(int(num_bins), "quantile"), col_label
return cls(int(num_bins), "quantile"), _unescape_id(col_label)
else:
array_split_points = s[s.find("[") + 1 : s.find("]")]
n_bins = array_split_points.count(",") + 2
return cls(n_bins, "uniform"), col_label
return cls(n_bins, "uniform"), _unescape_id(col_label)

def fit(
self,
Expand Down Expand Up @@ -469,7 +471,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[OneHotEncoder, str]:
max_categories = int(top_k) + 1
min_frequency = int(frequency_threshold)

return cls(drop, min_frequency, max_categories), col_label
return cls(drop, min_frequency, max_categories), _unescape_id(col_label)

def fit(
self,
Expand Down Expand Up @@ -578,7 +580,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[LabelEncoder, str]:
max_categories = int(top_k) + 1
min_frequency = int(frequency_threshold)

return cls(min_frequency, max_categories), col_label
return cls(min_frequency, max_categories), _unescape_id(col_label)

def fit(
self,
Expand Down Expand Up @@ -661,7 +663,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[PolynomialFeatures, tuple[str, ...]]
col_labels = sql[sql.find("STRUCT(") + 7 : sql.find(")")].split(",")
col_labels = [label.strip() for label in col_labels]
degree = int(sql[sql.rfind(",") + 1 : sql.rfind(")")])
return cls(degree), tuple(col_labels)
return cls(degree), tuple(map(_unescape_id, col_labels))

def fit(
self,
Expand Down Expand Up @@ -694,6 +696,14 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
)


def _unescape_id(id: str) -> str:
"""Very simple conversion to removed ` characters from ids.
A proper sql parser should be used instead.
"""
return id.removeprefix("`").removesuffix("`")


PreprocessingType = Union[
OneHotEncoder,
StandardScaler,
Expand Down
Loading

0 comments on commit 103e998

Please sign in to comment.