From 103e99823d442a36b2aaa5113950b988f6d3ba1e Mon Sep 17 00:00:00 2001
From: TrevorBergeron
Date: Thu, 10 Oct 2024 18:24:32 -0700
Subject: [PATCH] fix: Escape ids more consistently in ml module (#1074)

Use the shared escaping helpers (bigframes.core.compile.googlesql.identifier
and bigframes.core.sql.simple_literal) when generating ML SQL instead of
ad-hoc quoting, and pass a bigquery.ModelReference rather than a raw model
name string to ModelManipulationSqlGenerator.
---
 bigframes/core/utils.py               |  4 +-
 bigframes/ml/compose.py               | 17 +++--
 bigframes/ml/core.py                  |  4 +-
 bigframes/ml/impute.py                | 10 ++-
 bigframes/ml/preprocessing.py         | 26 +++++---
 bigframes/ml/sql.py                   | 87 +++++++++++++-----------
 tests/system/large/ml/test_compose.py | 20 +++---
 tests/unit/core/test_bf_utils.py      |  6 +-
 tests/unit/ml/test_compose.py         | 30 ++++-----
 tests/unit/ml/test_golden_sql.py      | 16 ++---
 tests/unit/ml/test_sql.py             | 95 ++++++++++++++-------------
 11 files changed, 173 insertions(+), 142 deletions(-)

diff --git a/bigframes/core/utils.py b/bigframes/core/utils.py
index 43c05c6c83..e684ac55a4 100644
--- a/bigframes/core/utils.py
+++ b/bigframes/core/utils.py
@@ -116,9 +116,9 @@ def label_to_identifier(label: typing.Hashable, strict: bool = False) -> str:
     """
     # Column values will be loaded as null if the column name has spaces.
     # https://github.com/googleapis/python-bigquery/issues/1566
-    identifier = str(label).replace(" ", "_")
-
+    identifier = str(label)
     if strict:
+        identifier = str(label).replace(" ", "_")
         identifier = re.sub(r"[^a-zA-Z0-9_]", "", identifier)
         if not identifier:
             identifier = "id"
diff --git a/bigframes/ml/compose.py b/bigframes/ml/compose.py
index 08c9761cc3..14cf12014f 100644
--- a/bigframes/ml/compose.py
+++ b/bigframes/ml/compose.py
@@ -28,6 +28,7 @@
 from google.cloud import bigquery
 
 from bigframes.core import log_adapter
+import bigframes.core.compile.googlesql as sql_utils
 from bigframes.ml import base, core, globals, impute, preprocessing, utils
 import bigframes.pandas as bpd
 
@@ -98,16 +99,11 @@ class SQLScalarColumnTransformer:
     def __init__(self, sql: str, target_column: str = "transformed_{0}"):
         super().__init__()
         self._sql = sql
+        # TODO: More robust unescaping
         self._target_column = target_column.replace("`", "")
 
     PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)
 
-    def escape(self, colname: str):
-        colname = colname.replace("`", "")
-        if self.PLAIN_COLNAME_RX.match(colname):
-            return colname
-        return f"`{colname}`"
-
     def _compile_to_sql(
         self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
     ) -> List[str]:
@@ -115,8 +111,10 @@ def _compile_to_sql(
             columns = X.columns
         result = []
         for column in columns:
-            current_sql = self._sql.format(self.escape(column))
-            current_target_column = self.escape(self._target_column.format(column))
+            current_sql = self._sql.format(sql_utils.identifier(column))
+            current_target_column = sql_utils.identifier(
+                self._target_column.format(column)
+            )
             result.append(f"{current_sql} AS {current_target_column}")
         return result
 
@@ -239,6 +237,7 @@ def camel_to_snake(name):
                 transformers_set.add(
                     (
                         camel_to_snake(transformer_cls.__name__),
+                        # TODO: This is very fragile, use a real SQL parser
                         *transformer_cls._parse_from_sql(transform_sql),  # type: ignore
                     )
                 )
@@ -253,7 +252,7 @@ def camel_to_snake(name):
                 target_column = transform_col_dict["name"]
 
                 sql_transformer = SQLScalarColumnTransformer(
-                    transform_sql, target_column=target_column
+                    transform_sql.strip(), target_column=target_column
                 )
                 input_column_name = f"?{target_column}"
                 transformers_set.add(
diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index 02ccc9d6a5..4bc61c5015 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -47,8 +47,10 @@ class BqmlModel(BaseBqml):
     def __init__(self, session: bigframes.Session, model: bigquery.Model):
         self._session = session
         self._model = model
+        model_ref = self._model.reference
+        assert model_ref is not None
         self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
-            self.model_name
+            model_ref
         )
 
     def _apply_ml_tvf(
diff --git a/bigframes/ml/impute.py b/bigframes/ml/impute.py
index 4955eb5de5..dddade8cc5 100644
--- a/bigframes/ml/impute.py
+++ b/bigframes/ml/impute.py
@@ -80,7 +80,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
             tuple(SimpleImputer, column_label)"""
         s = sql[sql.find("(") + 1 : sql.find(")")]
         col_label, strategy = s.split(", ")
-        return cls(strategy[1:-1]), col_label  # type: ignore[arg-type]
+        return cls(strategy[1:-1]), _unescape_id(col_label)  # type: ignore[arg-type]
 
     def fit(
         self,
@@ -110,3 +110,11 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
             bpd.DataFrame,
             df[self._output_names],
         )
+
+
+def _unescape_id(id: str) -> str:
+    """Very simple conversion to remove ` characters from ids.
+
+    A proper SQL parser should be used instead.
+    """
+    return id.removeprefix("`").removesuffix("`")
diff --git a/bigframes/ml/preprocessing.py b/bigframes/ml/preprocessing.py
index 2c327f63f8..eb53904a78 100644
--- a/bigframes/ml/preprocessing.py
+++ b/bigframes/ml/preprocessing.py
@@ -76,7 +76,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[StandardScaler, str]:
         Returns:
             tuple(StandardScaler, column_label)"""
         col_label = sql[sql.find("(") + 1 : sql.find(")")]
-        return cls(), col_label
+        return cls(), _unescape_id(col_label)
 
     def fit(
         self,
@@ -152,8 +152,9 @@ def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]:
 
         Returns:
             tuple(MaxAbsScaler, column_label)"""
+        # TODO: Use a real SQL parser
         col_label = sql[sql.find("(") + 1 : sql.find(")")]
-        return cls(), col_label
+        return cls(), _unescape_id(col_label)
 
     def fit(
         self,
@@ -229,8 +230,9 @@ def _parse_from_sql(cls, sql: str) -> tuple[MinMaxScaler, str]:
 
         Returns:
             tuple(MinMaxScaler, column_label)"""
+        # TODO: Use a real SQL parser
        col_label = sql[sql.find("(") + 1 : sql.find(")")]
-        return cls(), col_label
+        return cls(), _unescape_id(col_label)
 
     def fit(
         self,
@@ -349,11 +351,11 @@ def _parse_from_sql(cls, sql: str) -> tuple[KBinsDiscretizer, str]:
 
         if sql.startswith("ML.QUANTILE_BUCKETIZE"):
             num_bins = s.split(",")[1]
-            return cls(int(num_bins), "quantile"), col_label
+            return cls(int(num_bins), "quantile"), _unescape_id(col_label)
         else:
             array_split_points = s[s.find("[") + 1 : s.find("]")]
             n_bins = array_split_points.count(",") + 2
-            return cls(n_bins, "uniform"), col_label
+            return cls(n_bins, "uniform"), _unescape_id(col_label)
 
     def fit(
         self,
@@ -469,7 +471,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[OneHotEncoder, str]:
             max_categories = int(top_k) + 1
             min_frequency = int(frequency_threshold)
 
-        return cls(drop, min_frequency, max_categories), col_label
+        return cls(drop, min_frequency, max_categories), _unescape_id(col_label)
 
     def fit(
         self,
@@ -578,7 +580,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[LabelEncoder, str]:
             max_categories = int(top_k) + 1
             min_frequency = int(frequency_threshold)
 
-        return cls(min_frequency, max_categories), col_label
+        return cls(min_frequency, max_categories), _unescape_id(col_label)
 
     def fit(
         self,
@@ -661,7 +663,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[PolynomialFeatures, tuple[str, ...]]
         col_labels = sql[sql.find("STRUCT(") + 7 : sql.find(")")].split(",")
         col_labels = [label.strip() for label in col_labels]
         degree = int(sql[sql.rfind(",") + 1 : sql.rfind(")")])
-        return cls(degree), tuple(col_labels)
+        return cls(degree), tuple(map(_unescape_id, col_labels))
 
     def fit(
         self,
@@ -694,6 +696,14 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )
 
 
+def _unescape_id(id: str) -> str:
+    """Very simple conversion to remove ` characters from ids.
+
+    A proper SQL parser should be used instead.
+    """
+    return id.removeprefix("`").removesuffix("`")
+
+
 PreprocessingType = Union[
     OneHotEncoder,
     StandardScaler,
diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py
index 1cb327f19c..b7d550ac63 100644
--- a/bigframes/ml/sql.py
+++ b/bigframes/ml/sql.py
@@ -21,6 +21,9 @@
 import bigframes_vendored.constants as constants
 import google.cloud.bigquery
 
+import bigframes.core.compile.googlesql as sql_utils
+import bigframes.core.sql as sql_vals
+
 
 # TODO: Add proper escaping logic from core/compile module
 class BaseSqlGenerator:
@@ -29,10 +32,8 @@ class BaseSqlGenerator:
     # General methods
     def encode_value(self, v: Union[str, int, float, Iterable[str]]) -> str:
         """Encode a parameter value for SQL"""
-        if isinstance(v, str):
-            return f'"{v}"'
-        elif isinstance(v, int) or isinstance(v, float):
-            return f"{v}"
+        if isinstance(v, (str, int, float)):
+            return sql_vals.simple_literal(v)
         elif isinstance(v, Iterable):
             inner = ", ".join([self.encode_value(x) for x in v])
             return f"[{inner}]"
@@ -50,7 +51,10 @@ def build_parameters(self, **kwargs: Union[str, int, float, Iterable[str]]) -> str:
     def build_structs(self, **kwargs: Union[int, float]) -> str:
         """Encode a dict of values into a formatted STRUCT items for SQL"""
         indent_str = "  "
-        param_strs = [f"{v} AS {k}" for k, v in kwargs.items()]
+        param_strs = [
+            f"{sql_vals.simple_literal(v)} AS {sql_utils.identifier(k)}"
+            for k, v in kwargs.items()
+        ]
         return "\n" + indent_str + f",\n{indent_str}".join(param_strs)
 
     def build_expressions(self, *expr_sqls: str) -> str:
@@ -61,7 +65,7 @@ def build_expressions(self, *expr_sqls: str) -> str:
     def build_schema(self, **kwargs: str) -> str:
         """Encode a dict of values into a formatted schema type items for SQL"""
         indent_str = "  "
-        param_strs = [f"{k} {v}" for k, v in kwargs.items()]
+        param_strs = [f"{sql_utils.identifier(k)} {v}" for k, v in kwargs.items()]
         return "\n" + indent_str + f",\n{indent_str}".join(param_strs)
 
     def options(self, **kwargs: Union[str, int, float, Iterable[str]]) -> str:
@@ -74,7 +78,7 @@ def struct_options(self, **kwargs: Union[int, float]) -> str:
 
     def struct_columns(self, columns: Iterable[str]) -> str:
         """Encode BQ table columns to a STRUCT."""
-        columns_str = ", ".join(columns)
+        columns_str = ", ".join(map(sql_utils.identifier, columns))
         return f"STRUCT({columns_str})"
 
     def input(self, **kwargs: str) -> str:
@@ -97,30 +101,30 @@ def transform(self, *expr_sqls: str) -> str:
 
     def ml_standard_scaler(self, numeric_expr_sql: str, name: str) -> str:
         """Encode ML.STANDARD_SCALER for BQML"""
-        return f"""ML.STANDARD_SCALER({numeric_expr_sql}) OVER() AS {name}"""
+        return f"""ML.STANDARD_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}"""
 
     def ml_max_abs_scaler(self, numeric_expr_sql: str, name: str) -> str:
         """Encode ML.MAX_ABS_SCALER for BQML"""
-        return f"""ML.MAX_ABS_SCALER({numeric_expr_sql}) OVER() AS {name}"""
+        return f"""ML.MAX_ABS_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}"""
 
     def ml_min_max_scaler(self, numeric_expr_sql: str, name: str) -> str:
         """Encode ML.MIN_MAX_SCALER for BQML"""
-        return f"""ML.MIN_MAX_SCALER({numeric_expr_sql}) OVER() AS {name}"""
+        return f"""ML.MIN_MAX_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}"""
f"""ML.MIN_MAX_SCALER({sql_utils.identifier(numeric_expr_sql)}) OVER() AS {sql_utils.identifier(name)}""" def ml_imputer( self, - expr_sql: str, + col_name: str, strategy: str, name: str, ) -> str: """Encode ML.IMPUTER for BQML""" - return f"""ML.IMPUTER({expr_sql}, '{strategy}') OVER() AS {name}""" + return f"""ML.IMPUTER({sql_utils.identifier(col_name)}, '{strategy}') OVER() AS {sql_utils.identifier(name)}""" def ml_bucketize( self, - numeric_expr_sql: str, + input_id: str, array_split_points: Iterable[Union[int, float]], - name: str, + output_id: str, ) -> str: """Encode ML.BUCKETIZE for BQML""" # Use Python value rather than Numpy value to serialization. @@ -128,7 +132,7 @@ def ml_bucketize( point.item() if hasattr(point, "item") else point for point in array_split_points ] - return f"""ML.BUCKETIZE({numeric_expr_sql}, {points}, FALSE) AS {name}""" + return f"""ML.BUCKETIZE({sql_utils.identifier(input_id)}, {points}, FALSE) AS {sql_utils.identifier(output_id)}""" def ml_quantile_bucketize( self, @@ -137,7 +141,7 @@ def ml_quantile_bucketize( name: str, ) -> str: """Encode ML.QUANTILE_BUCKETIZE for BQML""" - return f"""ML.QUANTILE_BUCKETIZE({numeric_expr_sql}, {num_bucket}) OVER() AS {name}""" + return f"""ML.QUANTILE_BUCKETIZE({sql_utils.identifier(numeric_expr_sql)}, {num_bucket}) OVER() AS {sql_utils.identifier(name)}""" def ml_one_hot_encoder( self, @@ -149,7 +153,7 @@ def ml_one_hot_encoder( ) -> str: """Encode ML.ONE_HOT_ENCODER for BQML. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-one-hot-encoder for params.""" - return f"""ML.ONE_HOT_ENCODER({numeric_expr_sql}, '{drop}', {top_k}, {frequency_threshold}) OVER() AS {name}""" + return f"""ML.ONE_HOT_ENCODER({sql_utils.identifier(numeric_expr_sql)}, '{drop}', {top_k}, {frequency_threshold}) OVER() AS {sql_utils.identifier(name)}""" def ml_label_encoder( self, @@ -160,14 +164,14 @@ def ml_label_encoder( ) -> str: """Encode ML.LABEL_ENCODER for BQML. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-label-encoder for params.""" - return f"""ML.LABEL_ENCODER({numeric_expr_sql}, {top_k}, {frequency_threshold}) OVER() AS {name}""" + return f"""ML.LABEL_ENCODER({sql_utils.identifier(numeric_expr_sql)}, {top_k}, {frequency_threshold}) OVER() AS {sql_utils.identifier(name)}""" def ml_polynomial_expand( self, columns: Iterable[str], degree: int, name: str ) -> str: """Encode ML.POLYNOMIAL_EXPAND. https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-polynomial-expand""" - return f"""ML.POLYNOMIAL_EXPAND({self.struct_columns(columns)}, {degree}) AS {name}""" + return f"""ML.POLYNOMIAL_EXPAND({self.struct_columns(columns)}, {degree}) AS {sql_utils.identifier(name)}""" def ml_distance( self, @@ -179,7 +183,7 @@ def ml_distance( ) -> str: """Encode ML.DISTANCE for BQML. 
         https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-distance"""
-        return f"""SELECT *, ML.DISTANCE({col_x}, {col_y}, '{type}') AS {name} FROM ({source_sql})"""
+        return f"""SELECT *, ML.DISTANCE({sql_utils.identifier(col_x)}, {sql_utils.identifier(col_y)}, '{type}') AS {sql_utils.identifier(name)} FROM ({source_sql})"""
 
 
 class ModelCreationSqlGenerator(BaseSqlGenerator):
@@ -189,7 +193,7 @@ def _model_id_sql(
         self,
         model_ref: google.cloud.bigquery.ModelReference,
     ):
-        return f"`{model_ref.project}`.`{model_ref.dataset_id}`.`{model_ref.model_id}`"
+        return f"{sql_utils.identifier(model_ref.project)}.{sql_utils.identifier(model_ref.dataset_id)}.{sql_utils.identifier(model_ref.model_id)}"
 
     # Model create and alter
     def create_model(
@@ -276,8 +280,11 @@ def create_xgboost_imported_model(
 class ModelManipulationSqlGenerator(BaseSqlGenerator):
-    """Sql generator for manipulating a model entity. Model name is the full model path of project_id.dataset_id.model_id."""
+    """Sql generator for manipulating a model entity, identified by a bigquery.ModelReference."""
 
-    def __init__(self, model_name: str):
-        self._model_name = model_name
+    def __init__(self, model_ref: google.cloud.bigquery.ModelReference):
+        self._model_ref = model_ref
+
+    def _model_ref_sql(self) -> str:
+        return f"{sql_utils.identifier(self._model_ref.project)}.{sql_utils.identifier(self._model_ref.dataset_id)}.{sql_utils.identifier(self._model_ref.model_id)}"
 
     # Alter model
     def alter_model(
@@ -287,20 +294,20 @@ def alter_model(
         """Encode the ALTER MODEL statement for BQML"""
         options_sql = self.options(**options)
 
-        parts = [f"ALTER MODEL `{self._model_name}`"]
+        parts = [f"ALTER MODEL {self._model_ref_sql()}"]
         parts.append(f"SET {options_sql}")
         return "\n".join(parts)
 
     # ML prediction TVFs
     def ml_predict(self, source_sql: str) -> str:
         """Encode ML.PREDICT for BQML"""
-        return f"""SELECT * FROM ML.PREDICT(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
   ({source_sql}))"""
 
     def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
         """Encode ML.FORECAST for BQML"""
         struct_options_sql = self.struct_options(**struct_options)
-        return f"""SELECT * FROM ML.FORECAST(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()},
   {struct_options_sql})"""
 
     def ml_generate_text(
@@ -308,7 +315,7 @@ def ml_generate_text(
     ) -> str:
         """Encode ML.GENERATE_TEXT for BQML"""
         struct_options_sql = self.struct_options(**struct_options)
-        return f"""SELECT * FROM ML.GENERATE_TEXT(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.GENERATE_TEXT(MODEL {self._model_ref_sql()},
   ({source_sql}), {struct_options_sql})"""
 
     def ml_generate_embedding(
@@ -316,7 +323,7 @@ def ml_generate_embedding(
     ) -> str:
         """Encode ML.GENERATE_EMBEDDING for BQML"""
         struct_options_sql = self.struct_options(**struct_options)
-        return f"""SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.GENERATE_EMBEDDING(MODEL {self._model_ref_sql()},
   ({source_sql}), {struct_options_sql})"""
 
     def ml_detect_anomalies(
@@ -324,51 +331,51 @@ def ml_detect_anomalies(
     ) -> str:
         """Encode ML.DETECT_ANOMALIES for BQML"""
         struct_options_sql = self.struct_options(**struct_options)
-        return f"""SELECT * FROM ML.DETECT_ANOMALIES(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.DETECT_ANOMALIES(MODEL {self._model_ref_sql()},
   {struct_options_sql}, ({source_sql}))"""
 
     # ML evaluation TVFs
     def ml_evaluate(self, source_sql: Optional[str] = None) -> str:
         """Encode ML.EVALUATE for BQML"""
         if source_sql is None:
-            return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`)"""
+            return f"""SELECT * FROM ML.EVALUATE(MODEL {self._model_ref_sql()})"""
         else:
-            return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
+            return f"""SELECT * FROM ML.EVALUATE(MODEL {self._model_ref_sql()},
   ({source_sql}))"""
 
     def ml_arima_coefficients(self) -> str:
         """Encode ML.ARIMA_COEFFICIENTS for BQML"""
-        return f"""SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `{self._model_name}`)"""
+        return f"""SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL {self._model_ref_sql()})"""
 
     # ML evaluation TVFs
     def ml_llm_evaluate(self, source_sql: str, task_type: Optional[str] = None) -> str:
         """Encode ML.EVALUATE for BQML"""
         # Note: don't need index as evaluate returns a new table
-        return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.EVALUATE(MODEL {self._model_ref_sql()},
   ({source_sql}), STRUCT("{task_type}" AS task_type))"""
 
     # ML evaluation TVFs
     def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str:
         """Encode ML.ARIMA_EVALUATE for BQML"""
-        return f"""SELECT * FROM ML.ARIMA_EVALUATE(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.ARIMA_EVALUATE(MODEL {self._model_ref_sql()},
   STRUCT({show_all_candidate_models} AS show_all_candidate_models))"""
 
     def ml_centroids(self) -> str:
         """Encode ML.CENTROIDS for BQML"""
-        return f"""SELECT * FROM ML.CENTROIDS(MODEL `{self._model_name}`)"""
+        return f"""SELECT * FROM ML.CENTROIDS(MODEL {self._model_ref_sql()})"""
 
     def ml_principal_components(self) -> str:
         """Encode ML.PRINCIPAL_COMPONENTS for BQML"""
-        return f"""SELECT * FROM ML.PRINCIPAL_COMPONENTS(MODEL `{self._model_name}`)"""
+        return (
+            f"""SELECT * FROM ML.PRINCIPAL_COMPONENTS(MODEL {self._model_ref_sql()})"""
+        )
 
     def ml_principal_component_info(self) -> str:
         """Encode ML.PRINCIPAL_COMPONENT_INFO for BQML"""
-        return (
-            f"""SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `{self._model_name}`)"""
-        )
+        return f"""SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL {self._model_ref_sql()})"""
 
     # ML transform TVF, which requires a transform_only type model
     def ml_transform(self, source_sql: str) -> str:
         """Encode ML.TRANSFORM for BQML"""
-        return f"""SELECT * FROM ML.TRANSFORM(MODEL `{self._model_name}`,
+        return f"""SELECT * FROM ML.TRANSFORM(MODEL {self._model_ref_sql()},
   ({source_sql}))"""
diff --git a/tests/system/large/ml/test_compose.py b/tests/system/large/ml/test_compose.py
index ba963837e5..cbc702018a 100644
--- a/tests/system/large/ml/test_compose.py
+++ b/tests/system/large/ml/test_compose.py
@@ -90,12 +90,14 @@ def test_columntransformer_standalone_fit_and_transform(
 
 
 def test_columntransformer_standalone_fit_transform(new_penguins_df):
+    # rename column to ensure robustness to column names that must be escaped
+    new_penguins_df = new_penguins_df.rename(columns={"species": "123 'species'"})
     transformer = compose.ColumnTransformer(
         [
             (
                 "onehot",
                 preprocessing.OneHotEncoder(),
-                "species",
+                "123 'species'",
             ),
             (
                 "standard_scale",
@@ -108,7 +110,7 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
                     "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END",
                     target_column="len_{0}",
                 ),
-                "species",
+                "123 'species'",
             ),
             (
                 "identity",
@@ -119,16 +121,16 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
     )
 
     result = transformer.fit_transform(
-        new_penguins_df[["species", "culmen_length_mm", "flipper_length_mm"]]
+        new_penguins_df[["123 'species'", "culmen_length_mm", "flipper_length_mm"]]
     ).to_pandas()
 
     utils.check_pandas_df_schema_and_index(
         result,
         columns=[
-            "onehotencoded_species",
+            "onehotencoded_123 'species'",
             "standard_scaled_culmen_length_mm",
             "standard_scaled_flipper_length_mm",
-            "len_species",
+            "len_123 'species'",
             "culmen_length_mm",
             "flipper_length_mm",
         ],
@@ -194,7 +196,7 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
         (
             "sql_scalar_column_transformer",
             compose.SQLScalarColumnTransformer(
-                "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END",
+                "CASE WHEN `species` IS NULL THEN -1 ELSE LENGTH(`species`) END",
                 target_column="len_species",
             ),
             "?len_species",
         ),
         (
             "sql_scalar_column_transformer",
             compose.SQLScalarColumnTransformer(
-                "flipper_length_mm", target_column="flipper_length_mm"
+                "`flipper_length_mm`", target_column="flipper_length_mm"
             ),
             "?flipper_length_mm",
         ),
         (
             "sql_scalar_column_transformer",
             compose.SQLScalarColumnTransformer(
-                "culmen_length_mm", target_column="culmen_length_mm"
+                "`culmen_length_mm`", target_column="culmen_length_mm"
             ),
             "?culmen_length_mm",
         ),
         (
             "sql_scalar_column_transformer",
             compose.SQLScalarColumnTransformer(
-                "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END ",
+                "CASE WHEN `species` IS NULL THEN -1 ELSE LENGTH(`species`) END",
                 target_column="Flex species Name",
             ),
             "?Flex species Name",
diff --git a/tests/unit/core/test_bf_utils.py b/tests/unit/core/test_bf_utils.py
index 10ce1fd09e..248b6796e2 100644
--- a/tests/unit/core/test_bf_utils.py
+++ b/tests/unit/core/test_bf_utils.py
@@ -26,7 +26,7 @@ def test_get_standardized_ids_columns():
         utils.UNNAMED_COLUMN_ID,
         "duplicate",
         "duplicate_1",
-        "with_space",
+        "with space",
     ]
 
     assert idx_ids == []
@@ -35,7 +35,7 @@ def test_get_standardized_ids_indexes():
     col_labels = ["duplicate"]
     idx_labels = ["string", 0, None, "duplicate", "duplicate", "with space"]
 
-    col_ids, idx_ids = utils.get_standardized_ids(col_labels, idx_labels)
+    col_ids, idx_ids = utils.get_standardized_ids(col_labels, idx_labels, strict=True)
 
     assert col_ids == ["duplicate_2"]
     assert idx_ids == [
@@ -53,4 +53,4 @@ def test_get_standardized_ids_tuple():
 
     col_ids, _ = utils.get_standardized_ids(col_labels)
 
-    assert col_ids == ["('foo',_1)", "('foo',_2)", "('bar',_1)"]
+    assert col_ids == ["('foo', 1)", "('foo', 2)", "('bar', 1)"]
diff --git a/tests/unit/ml/test_compose.py b/tests/unit/ml/test_compose.py
index 7643f76e56..395296f3e4 100644
--- a/tests/unit/ml/test_compose.py
+++ b/tests/unit/ml/test_compose.py
@@ -258,8 +258,8 @@ def test_customtransformer_compile_sql(mock_X):
     ident_trafo = SQLScalarColumnTransformer("{0}", target_column="ident_{0}")
     sqls = ident_trafo._compile_to_sql(X=mock_X, columns=["col1", "col2"])
     assert sqls == [
-        "col1 AS ident_col1",
-        "col2 AS ident_col2",
+        "`col1` AS `ident_col1`",
+        "`col2` AS `ident_col2`",
     ]
 
     len1_trafo = SQLScalarColumnTransformer(
@@ -267,8 +267,8 @@ def test_customtransformer_compile_sql(mock_X):
     )
     sqls = len1_trafo._compile_to_sql(X=mock_X, columns=["col1", "col2"])
     assert sqls == [
-        "CASE WHEN col1 IS NULL THEN -5 ELSE LENGTH(col1) END AS len1_col1",
-        "CASE WHEN col2 IS NULL THEN -5 ELSE LENGTH(col2) END AS len1_col2",
+        "CASE WHEN `col1` IS NULL THEN -5 ELSE LENGTH(`col1`) END AS `len1_col1`",
+        "CASE WHEN `col2` IS NULL THEN -5 ELSE LENGTH(`col2`) END AS `len1_col2`",
     ]
 
     len2_trafo = SQLScalarColumnTransformer(
@@ -276,8 +276,8 @@ def test_customtransformer_compile_sql(mock_X):
     )
     sqls = len2_trafo._compile_to_sql(X=mock_X, columns=["col1", "col2"])
     assert sqls == [
-        "CASE WHEN col1 IS NULL THEN 99 ELSE LENGTH(col1) END AS len2_col1",
-        "CASE WHEN col2 IS NULL THEN 99 ELSE LENGTH(col2) END AS len2_col2",
+        "CASE WHEN `col1` IS NULL THEN 99 ELSE LENGTH(`col1`) END AS `len2_col1`",
+        "CASE WHEN `col2` IS NULL THEN 99 ELSE LENGTH(`col2`) END AS `len2_col2`",
     ]
 
 
@@ -524,11 +524,11 @@ def test_columntransformer_compile_to_sql(mock_X):
     )
     sqls = column_transformer._compile_to_sql(mock_X)
     assert sqls == [
-        "culmen_length_mm AS ident_culmen_length_mm",
-        "flipper_length_mm AS ident_flipper_length_mm",
-        "CASE WHEN species IS NULL THEN -2 ELSE LENGTH(species) END AS len1_species",
-        "CASE WHEN species IS NULL THEN 99 ELSE LENGTH(species) END AS len2_species",
-        "ML.LABEL_ENCODER(species, 1000000, 0) OVER() AS labelencoded_species",
+        "`culmen_length_mm` AS `ident_culmen_length_mm`",
+        "`flipper_length_mm` AS `ident_flipper_length_mm`",
+        "CASE WHEN `species` IS NULL THEN -2 ELSE LENGTH(`species`) END AS `len1_species`",
+        "CASE WHEN `species` IS NULL THEN 99 ELSE LENGTH(`species`) END AS `len2_species`",
+        "ML.LABEL_ENCODER(`species`, 1000000, 0) OVER() AS `labelencoded_species`",
     ]
 
 
@@ -548,13 +548,13 @@ def test_columntransformer_flexible_column_names(mock_X):
             ["culmen_length_mm", "flipper_length_mm"],
         ),
         ("len1_trafo", len1_transformer, ["species shortname"]),
-        ("len2_trafo", len2_transformer, ["`species longname`"]),
+        ("len2_trafo", len2_transformer, ["species longname"]),
     ]
 )
     sqls = column_transformer._compile_to_sql(mock_X)
     assert sqls == [
-        "culmen_length_mm AS `ident culmen_length_mm`",
-        "flipper_length_mm AS `ident flipper_length_mm`",
+        "`culmen_length_mm` AS `ident culmen_length_mm`",
+        "`flipper_length_mm` AS `ident flipper_length_mm`",
         "CASE WHEN `species shortname` IS NULL THEN -2 ELSE LENGTH(`species shortname`) END AS `len1_species shortname`",
         "CASE WHEN `species longname` IS NULL THEN 99 ELSE LENGTH(`species longname`) END AS `len2_species longname`",
     ]
@@ -576,6 +576,6 @@ def test_columntransformer_extract_from_bq_model_flexnames(bq_model_flexnames):
                        SQLScalarColumnTransformer(sql='culmen_length_mm', target_column='Flex Name culmen_length_mm'),
                        '?Flex Name culmen_length_mm'),
                       ('sql_scalar_column_transformer',
-                       SQLScalarColumnTransformer(sql='flipper_length_mm ', target_column='Flex Name flipper_length_mm'),
+                       SQLScalarColumnTransformer(sql='flipper_length_mm', target_column='Flex Name flipper_length_mm'),
                        '?Flex Name flipper_length_mm')])"""
     assert expected == actual
diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py
index aa7e919b24..65f079852e 100644
--- a/tests/unit/ml/test_golden_sql.py
+++ b/tests/unit/ml/test_golden_sql.py
@@ -106,7 +106,7 @@ def test_linear_regression_default_fit(
     model.fit(mock_X, mock_y)
 
     mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type="LINEAR_REG",\n  data_split_method="NO_SPLIT",\n  optimize_strategy="auto_strategy",\n  fit_intercept=True,\n  l2_reg=0.0,\n  max_iterations=20,\n  learn_rate_strategy="line_search",\n  min_rel_progress=0.01,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type='LINEAR_REG',\n  data_split_method='NO_SPLIT',\n  optimize_strategy='auto_strategy',\n  fit_intercept=True,\n  l2_reg=0.0,\n  max_iterations=20,\n  learn_rate_strategy='line_search',\n  min_rel_progress=0.01,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
     )
 
 
@@ -116,7 +116,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, mock_y):
     model.fit(mock_X, mock_y)
 
     mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type="LINEAR_REG",\n  data_split_method="NO_SPLIT",\n  optimize_strategy="auto_strategy",\n  fit_intercept=False,\n  l2_reg=0.0,\n  max_iterations=20,\n  learn_rate_strategy="line_search",\n  min_rel_progress=0.01,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type='LINEAR_REG',\n  data_split_method='NO_SPLIT',\n  optimize_strategy='auto_strategy',\n  fit_intercept=False,\n  l2_reg=0.0,\n  max_iterations=20,\n  learn_rate_strategy='line_search',\n  min_rel_progress=0.01,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
     )
 
 
@@ -126,7 +126,7 @@ def test_linear_regression_predict(mock_session, bqml_model, mock_X):
     model.predict(mock_X)
 
     mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.PREDICT(MODEL `model_project.model_dataset.model_id`,\n  (input_X_sql))",
+        "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n  (input_X_sql))",
         index_col=["index_column_id"],
     )
 
 
@@ -137,7 +137,7 @@ def test_linear_regression_score(mock_session, bqml_model, mock_X, mock_y):
     model.score(mock_X, mock_y)
 
     mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project.model_dataset.model_id`,\n  (input_X_y_sql))"
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n  (input_X_y_sql))"
     )
 
 
@@ -149,7 +149,7 @@ def test_logistic_regression_default_fit(
     model.fit(mock_X, mock_y)
 
     mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type="LOGISTIC_REG",\n  data_split_method="NO_SPLIT",\n  fit_intercept=True,\n  auto_class_weights=False,\n  optimize_strategy="auto_strategy",\n  l2_reg=0.0,\n  max_iterations=20,\n  learn_rate_strategy="line_search",\n  min_rel_progress=0.01,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type='LOGISTIC_REG',\n  data_split_method='NO_SPLIT',\n  fit_intercept=True,\n  auto_class_weights=False,\n  optimize_strategy='auto_strategy',\n  l2_reg=0.0,\n  max_iterations=20,\n  learn_rate_strategy='line_search',\n  min_rel_progress=0.01,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
     )
 
 
@@ -171,7 +171,7 @@ def test_logistic_regression_params_fit(
     model.fit(mock_X, mock_y)
 
     mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type="LOGISTIC_REG",\n  data_split_method="NO_SPLIT",\n  fit_intercept=False,\n  auto_class_weights=True,\n  optimize_strategy="batch_gradient_descent",\n  l2_reg=0.2,\n  max_iterations=30,\n  learn_rate_strategy="constant",\n  min_rel_progress=0.02,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  l1_reg=0.2,\n  learn_rate=0.2,\n  INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n  model_type='LOGISTIC_REG',\n  data_split_method='NO_SPLIT',\n  fit_intercept=False,\n  auto_class_weights=True,\n  optimize_strategy='batch_gradient_descent',\n  l2_reg=0.2,\n  max_iterations=30,\n  learn_rate_strategy='constant',\n  min_rel_progress=0.02,\n  calculate_p_values=False,\n  enable_global_explain=False,\n  l1_reg=0.2,\n  learn_rate=0.2,\n  INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
     )
 
 
@@ -181,7 +181,7 @@ def test_logistic_regression_predict(mock_session, bqml_model, mock_X):
     model.predict(mock_X)
 
     mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.PREDICT(MODEL `model_project.model_dataset.model_id`,\n  (input_X_sql))",
+        "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n  (input_X_sql))",
         index_col=["index_column_id"],
     )
 
 
@@ -192,5 +192,5 @@ def test_logistic_regression_score(mock_session, bqml_model, mock_X, mock_y):
     model.score(mock_X, mock_y)
 
     mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project.model_dataset.model_id`,\n  (input_X_y_sql))"
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n  (input_X_y_sql))"
     )
diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py
index cdf2d0b2e4..ee0821dfe9 100644
--- a/tests/unit/ml/test_sql.py
+++ b/tests/unit/ml/test_sql.py
@@ -34,7 +34,9 @@ def model_creation_sql_generator() -> ml_sql.ModelCreationSqlGenerator:
 @pytest.fixture(scope="session")
 def model_manipulation_sql_generator() -> ml_sql.ModelManipulationSqlGenerator:
     return ml_sql.ModelManipulationSqlGenerator(
-        model_name="my_project_id.my_dataset_id.my_model_id"
+        model_ref=bigquery.ModelReference.from_string(
+            "my_project_id.my_dataset_id.my_model_id"
+        )
     )
 
 
@@ -53,7 +55,7 @@ def test_ml_arima_coefficients(
     sql = model_manipulation_sql_generator.ml_arima_coefficients()
     assert (
         sql
-        == """SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `my_project_id.my_dataset_id.my_model_id`)"""
+        == """SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
     )
 
 
@@ -64,8 +66,8 @@ def test_options_correct(base_sql_generator: ml_sql.BaseSqlGenerator):
     assert (
         sql
         == """OPTIONS(
-  model_type="lin_reg",
-  input_label_cols=["col_a"],
+  model_type='lin_reg',
+  input_label_cols=['col_a'],
   l1_reg=0.6)"""
     )
 
 
@@ -89,42 +91,42 @@ def test_standard_scaler_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_standard_scaler("col_a", "scaled_col_a")
-    assert sql == "ML.STANDARD_SCALER(col_a) OVER() AS scaled_col_a"
+    assert sql == "ML.STANDARD_SCALER(`col_a`) OVER() AS `scaled_col_a`"
 
 
 def test_max_abs_scaler_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_max_abs_scaler("col_a", "scaled_col_a")
-    assert sql == "ML.MAX_ABS_SCALER(col_a) OVER() AS scaled_col_a"
+    assert sql == "ML.MAX_ABS_SCALER(`col_a`) OVER() AS `scaled_col_a`"
 
 
 def test_min_max_scaler_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_min_max_scaler("col_a", "scaled_col_a")
-    assert sql == "ML.MIN_MAX_SCALER(col_a) OVER() AS scaled_col_a"
+    assert sql == "ML.MIN_MAX_SCALER(`col_a`) OVER() AS `scaled_col_a`"
 
 
 def test_imputer_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_imputer("col_a", "mean", "scaled_col_a")
-    assert sql == "ML.IMPUTER(col_a, 'mean') OVER() AS scaled_col_a"
+    assert sql == "ML.IMPUTER(`col_a`, 'mean') OVER() AS `scaled_col_a`"
 
 
 def test_k_bins_discretizer_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_bucketize("col_a", [1, 2, 3, 4], "scaled_col_a")
-    assert sql == "ML.BUCKETIZE(col_a, [1, 2, 3, 4], FALSE) AS scaled_col_a"
+    assert sql == "ML.BUCKETIZE(`col_a`, [1, 2, 3, 4], FALSE) AS `scaled_col_a`"
 
 
 def test_k_bins_discretizer_quantile_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_quantile_bucketize("col_a", 5, "scaled_col_a")
-    assert sql == "ML.QUANTILE_BUCKETIZE(col_a, 5) OVER() AS scaled_col_a"
+    assert sql == "ML.QUANTILE_BUCKETIZE(`col_a`, 5) OVER() AS `scaled_col_a`"
 
 
 def test_one_hot_encoder_correct(
@@ -134,7 +136,8 @@ def test_one_hot_encoder_correct(
         "col_a", "none", 1000000, 0, "encoded_col_a"
     )
     assert (
-        sql == "ML.ONE_HOT_ENCODER(col_a, 'none', 1000000, 0) OVER() AS encoded_col_a"
+        sql
+        == "ML.ONE_HOT_ENCODER(`col_a`, 'none', 1000000, 0) OVER() AS `encoded_col_a`"
     )
 
 
@@ -142,14 +145,14 @@ def test_label_encoder_correct(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_label_encoder("col_a", 1000000, 0, "encoded_col_a")
-    assert sql == "ML.LABEL_ENCODER(col_a, 1000000, 0) OVER() AS encoded_col_a"
+    assert sql == "ML.LABEL_ENCODER(`col_a`, 1000000, 0) OVER() AS `encoded_col_a`"
 
 
 def test_polynomial_expand(
     base_sql_generator: ml_sql.BaseSqlGenerator,
 ):
     sql = base_sql_generator.ml_polynomial_expand(["col_a", "col_b"], 2, "poly_exp")
-    assert sql == "ML.POLYNOMIAL_EXPAND(STRUCT(col_a, col_b), 2) AS poly_exp"
+    assert sql == "ML.POLYNOMIAL_EXPAND(STRUCT(`col_a`, `col_b`), 2) AS `poly_exp`"
 
 
 def test_create_model_correct(
@@ -167,7 +170,7 @@ def test_create_model_correct(
         sql
         == """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_model_correct_sql`
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2) AS input_X_y_sql"""
     )
 
 
@@ -195,7 +198,7 @@ def test_create_model_transform_correct(
   ML.STANDARD_SCALER(col_a) OVER(col_a) AS scaled_col_a,
   ML.ONE_HOT_ENCODER(col_b) OVER(col_b) AS encoded_col_b)
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2) AS input_X_y_sql"""
     )
 
 
@@ -218,7 +221,7 @@ def test_create_llm_remote_model_correct(
         == """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_remote_model`
 REMOTE WITH CONNECTION `my_project.us.my_connection`
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2) AS input_X_y_sql"""
     )
 
 
@@ -239,7 +242,7 @@ def test_create_remote_model_correct(
         == """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_remote_model`
 REMOTE WITH CONNECTION `my_project.us.my_connection`
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2)"""
     )
 
 
@@ -260,12 +263,12 @@ def test_create_remote_model_with_params_correct(
         sql
         == """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_remote_model`
 INPUT(
-  column1 int64)
+  `column1` int64)
 OUTPUT(
-  result array)
+  `result` array)
 REMOTE WITH CONNECTION `my_project.us.my_connection`
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2)"""
     )
 
 
@@ -283,7 +286,7 @@ def test_create_imported_model_correct(
         sql
         == """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_imported_model`
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2)"""
     )
 
 
@@ -303,11 +306,11 @@ def test_create_xgboost_imported_model_produces_correct_sql(
         sql
         == """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_xgboost_imported_model`
 INPUT(
-  column1 int64)
+  `column1` int64)
 OUTPUT(
-  result array)
+  `result` array)
 OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2)"""
     )
 
 
@@ -320,9 +323,9 @@ def test_alter_model_correct_sql(
     )
     assert (
         sql
-        == """ALTER MODEL `my_project_id.my_dataset_id.my_model_id`
+        == """ALTER MODEL `my_project_id`.`my_dataset_id`.`my_model_id`
 SET OPTIONS(
-  option_key1="option_value1",
+  option_key1='option_value1',
   option_key2=2)"""
     )
 
 
@@ -334,7 +337,7 @@ def test_ml_predict_correct(
     sql = model_manipulation_sql_generator.ml_predict(source_sql=mock_df.sql)
     assert (
         sql
-        == """SELECT * FROM ML.PREDICT(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   (input_X_y_sql))"""
     )
 
 
@@ -348,7 +351,7 @@ def test_ml_llm_evaluate_correct(
     )
     assert (
         sql
-        == """SELECT * FROM ML.EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.EVALUATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   (input_X_y_sql), STRUCT("CLASSIFICATION" AS task_type))"""
     )
 
 
@@ -360,7 +363,7 @@ def test_ml_evaluate_correct(
     sql = model_manipulation_sql_generator.ml_evaluate(source_sql=mock_df.sql)
     assert (
         sql
-        == """SELECT * FROM ML.EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.EVALUATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   (input_X_y_sql))"""
     )
 
 
@@ -373,7 +376,7 @@ def test_ml_arima_evaluate_correct(
     )
     assert (
         sql
-        == """SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   STRUCT(True AS show_all_candidate_models))"""
     )
 
 
@@ -384,7 +387,7 @@ def test_ml_evaluate_no_source_correct(
     sql = model_manipulation_sql_generator.ml_evaluate()
     assert (
         sql
-        == """SELECT * FROM ML.EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`)"""
+        == """SELECT * FROM ML.EVALUATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
     )
 
 
@@ -394,7 +397,7 @@ def test_ml_centroids_correct(
     sql = model_manipulation_sql_generator.ml_centroids()
     assert (
         sql
-        == """SELECT * FROM ML.CENTROIDS(MODEL `my_project_id.my_dataset_id.my_model_id`)"""
+        == """SELECT * FROM ML.CENTROIDS(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
     )
 
 
@@ -406,10 +409,10 @@ def test_ml_forecast_correct_sql(
     )
     assert (
         sql
-        == """SELECT * FROM ML.FORECAST(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.FORECAST(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   STRUCT(
-  1 AS option_key1,
-  2.2 AS option_key2))"""
+  1 AS `option_key1`,
+  2.2 AS `option_key2`))"""
     )
 
 
@@ -423,10 +426,10 @@ def test_ml_generate_text_correct(
     )
     assert (
         sql
-        == """SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   (input_X_y_sql), STRUCT(
-  1 AS option_key1,
-  2.2 AS option_key2))"""
+  1 AS `option_key1`,
+  2.2 AS `option_key2`))"""
     )
 
 
@@ -440,10 +443,10 @@ def test_ml_generate_embedding_correct(
     )
     assert (
         sql
-        == """SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   (input_X_y_sql), STRUCT(
-  1 AS option_key1,
-  2.2 AS option_key2))"""
+  1 AS `option_key1`,
+  2.2 AS `option_key2`))"""
     )
 
 
@@ -457,10 +460,10 @@ def test_ml_detect_anomalies_correct_sql(
     )
     assert (
         sql
-        == """SELECT * FROM ML.DETECT_ANOMALIES(MODEL `my_project_id.my_dataset_id.my_model_id`,
+        == """SELECT * FROM ML.DETECT_ANOMALIES(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
   STRUCT(
-  1 AS option_key1,
-  2.2 AS option_key2), (input_X_y_sql))"""
+  1 AS `option_key1`,
+  2.2 AS `option_key2`), (input_X_y_sql))"""
     )
 
 
@@ -470,7 +473,7 @@ def test_ml_principal_components_correct(
     sql = model_manipulation_sql_generator.ml_principal_components()
     assert (
         sql
-        == """SELECT * FROM ML.PRINCIPAL_COMPONENTS(MODEL `my_project_id.my_dataset_id.my_model_id`)"""
+        == """SELECT * FROM ML.PRINCIPAL_COMPONENTS(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
     )
 
 
@@ -480,5 +483,5 @@ def test_ml_principal_component_info_correct(
     sql = model_manipulation_sql_generator.ml_principal_component_info()
     assert (
         sql
-        == """SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `my_project_id.my_dataset_id.my_model_id`)"""
+        == """SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
     )