diff --git a/samples/snippets/classification_boosted_tree_model_test.py b/samples/snippets/classification_boosted_tree_model_test.py index 707ce16279..715c9d1098 100644 --- a/samples/snippets/classification_boosted_tree_model_test.py +++ b/samples/snippets/classification_boosted_tree_model_test.py @@ -48,19 +48,42 @@ def test_boosted_tree_model(random_model_id: str) -> None: y = training_data["income_bracket"] # create and train the model - census_model = ensemble.XGBClassifier( + tree_model = ensemble.XGBClassifier( n_estimators=1, booster="gbtree", tree_method="hist", max_iterations=1, # For a more accurate model, try 50 iterations. subsample=0.85, ) - census_model.fit(X, y) + tree_model.fit(X, y) - census_model.to_gbq( - your_model_id, # For example: "your-project.census.census_model" + tree_model.to_gbq( + your_model_id, # For example: "your-project.bqml_tutorial.tree_model" replace=True, ) # [END bigquery_dataframes_bqml_boosted_tree_create] + # [START bigquery_dataframes_bqml_boosted_tree_explain] + # Select model you'll use for predictions. `read_gbq_model` loads model + # data from BigQuery, but you could also use the `tree_model` object + # from the previous step. + tree_model = bpd.read_gbq_model( + your_model_id, # For example: "your-project.bqml_tutorial.tree_model" + ) + + # input_data is defined in an earlier step. + evaluation_data = input_data[input_data["dataframe"] == "evaluation"] + X = evaluation_data.drop(columns=["income_bracket", "dataframe"]) + y = evaluation_data["income_bracket"] + + # The score() method evaluates how the model performs compared to the + # actual data. Output DataFrame matches that of ML.EVALUATE(). + score = tree_model.score(X, y) + score.peek() + # Output: + # precision recall accuracy f1_score log_loss roc_auc + # 0 0.671924 0.578804 0.839429 0.621897 0.344054 0.887335 + # [END bigquery_dataframes_bqml_boosted_tree_explain] + assert tree_model is not None + assert evaluation_data is not None + assert score is not None assert input_data is not None - assert census_model is not None