diff --git a/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb b/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb new file mode 100644 index 0000000000..824d911aff --- /dev/null +++ b/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb @@ -0,0 +1,1105 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BigFrames ML Cross-Vaidation\n", + "\n", + "This demo shows how to do cross validation in bigframes.ml" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Prepare Data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import bigframes.pandas as bpd" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/google/home/garrettwu/src/bigframes/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577: UserWarning: Reading cached table from 2024-10-01 22:44:50.650768+00:00 to avoid incompatibilies with previous reads of this table. To read the latest version, set `use_cache=False` or close the current session with Session.close() or bigframes.pandas.close_session().\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n" + ] + }, + { + "data": { + "text/html": [ + "Query job 4c2f2252-687a-47c3-87ad-22db8ad96e2b is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job a05c7268-8db2-468b-9fb4-0fb5c9534f51 is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
speciesislandculmen_length_mmculmen_depth_mmflipper_length_mmbody_mass_gsex
0Gentoo penguin (Pygoscelis papua)Biscoe50.515.9225.05400.0MALE
1Gentoo penguin (Pygoscelis papua)Biscoe45.114.5215.05000.0FEMALE
2Adelie Penguin (Pygoscelis adeliae)Torgersen41.418.5202.03875.0MALE
3Adelie Penguin (Pygoscelis adeliae)Torgersen38.617.0188.02900.0FEMALE
4Gentoo penguin (Pygoscelis papua)Biscoe46.514.8217.05200.0FEMALE
........................
339Adelie Penguin (Pygoscelis adeliae)Dream38.117.6187.03425.0FEMALE
340Adelie Penguin (Pygoscelis adeliae)Biscoe36.417.1184.02850.0FEMALE
341Chinstrap penguin (Pygoscelis antarctica)Dream40.916.6187.03200.0FEMALE
342Adelie Penguin (Pygoscelis adeliae)Biscoe41.321.1195.04400.0MALE
343Chinstrap penguin (Pygoscelis antarctica)Dream45.216.6191.03250.0FEMALE
\n", + "

334 rows × 7 columns

\n", + "
[334 rows x 7 columns in total]" + ], + "text/plain": [ + " species island culmen_length_mm \\\n", + "0 Gentoo penguin (Pygoscelis papua) Biscoe 50.5 \n", + "1 Gentoo penguin (Pygoscelis papua) Biscoe 45.1 \n", + "2 Adelie Penguin (Pygoscelis adeliae) Torgersen 41.4 \n", + "3 Adelie Penguin (Pygoscelis adeliae) Torgersen 38.6 \n", + "4 Gentoo penguin (Pygoscelis papua) Biscoe 46.5 \n", + ".. ... ... ... \n", + "339 Adelie Penguin (Pygoscelis adeliae) Dream 38.1 \n", + "340 Adelie Penguin (Pygoscelis adeliae) Biscoe 36.4 \n", + "341 Chinstrap penguin (Pygoscelis antarctica) Dream 40.9 \n", + "342 Adelie Penguin (Pygoscelis adeliae) Biscoe 41.3 \n", + "343 Chinstrap penguin (Pygoscelis antarctica) Dream 45.2 \n", + "\n", + " culmen_depth_mm flipper_length_mm body_mass_g sex \n", + "0 15.9 225.0 5400.0 MALE \n", + "1 14.5 215.0 5000.0 FEMALE \n", + "2 18.5 202.0 3875.0 MALE \n", + "3 17.0 188.0 2900.0 FEMALE \n", + "4 14.8 217.0 5200.0 FEMALE \n", + ".. ... ... ... ... \n", + "339 17.6 187.0 3425.0 FEMALE \n", + "340 17.1 184.0 2850.0 FEMALE \n", + "341 16.6 187.0 3200.0 FEMALE \n", + "342 21.1 195.0 4400.0 MALE \n", + "343 16.6 191.0 3250.0 FEMALE \n", + "...\n", + "\n", + "[334 rows x 7 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read and filter out unavailable data\n", + "df = bpd.read_gbq(\"bigframes-dev.bqml_tutorial.penguins\")\n", + "df = df.dropna()\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Select X and y from the dataset\n", + "X = df[\n", + " [\n", + " \"species\",\n", + " \"island\",\n", + " \"culmen_length_mm\",\n", + " ]\n", + " ]\n", + "y = df[\"body_mass_g\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1 Define KFold class and Train/Test for Each Fold (Mauanl Approach)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from bigframes.ml import model_selection, linear_model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Create KFold instance, n_splits defines how many folds the data will split. For example, n_split=5 will split the entire dataset into 5 pieces. \n", + "# In each fold, 4 pieces will be used for training, and the other piece will be used for evaluation. \n", + "kf = model_selection.KFold(n_splits=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Query job 582e7c02-bcc6-412a-a513-46ee5dba7ad8 is DONE. 2.7 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 917ff09b-072b-4c55-b26f-1780e2e97519 is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 2f4e102d-48bc-401f-a781-39830e2c6c9b is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job aabe8a28-8dce-4e00-8a8c-18e9e090e6e7 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ec9d8798-e28e-44bc-aa8e-44ab28f0214f is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 8aa0fa94-e43e-41c6-9de3-f0a67392c47f is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + "0 318.358226 151689.571141 0.009814 \n", + "\n", + " median_absolute_error r2_score explained_variance \n", + "0 255.095561 0.780659 0.783304 \n", + "\n", + "[1 rows x 6 columns]\n" + ] + }, + { + "data": { + "text/html": [ + "Query job bf6ef937-9583-4aa8-8313-563638465d5f is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 4c8b564c-5bbd-4447-babf-e307524962e5 is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job cd5e337f-6d44-473d-a90b-be8a79bba6bf is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ad80012d-7c6c-4dbf-9271-2ff7f899f174 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 8fc20587-d8ba-4c0f-bed9-3e1cf3c6ae52 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + "0 306.435423 151573.84019 0.008539 \n", + "\n", + " median_absolute_error r2_score explained_variance \n", + "0 244.2899 0.737623 0.742859 \n", + "\n", + "[1 rows x 6 columns]\n" + ] + }, + { + "data": { + "text/html": [ + "Query job 90286d2b-e805-4b19-8876-c9973579e9ff is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ceb6c8f2-16cc-4758-bde8-3e4975ba1452 is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job f49434fa-a7e0-406a-bbe2-5651595e3418 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 5dd7a277-10fe-4117-a354-ef8668a8b913 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 4b58b016-9a50-4a66-b86c-8431faad43bf is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + "0 253.349578 112039.741164 0.007153 \n", + "\n", + " median_absolute_error r2_score explained_variance \n", + "0 185.916761 0.823381 0.823456 \n", + "\n", + "[1 rows x 6 columns]\n" + ] + }, + { + "data": { + "text/html": [ + "Query job ca700ecf-0c08-4286-b979-2bc7a0bee89c is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job f0731e71-7754-47a2-a553-93a61e712533 is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ae66d34d-5f0a-4297-9d41-57067ae54a9b is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 7655a649-ceca-4792-b764-fb371f5872ec is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 8b0634c8-73a9-422c-9644-842142dbb059 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + "0 320.381386 155234.800349 0.008638 \n", + "\n", + " median_absolute_error r2_score explained_variance \n", + "0 306.281263 0.793405 0.794504 \n", + "\n", + "[1 rows x 6 columns]\n" + ] + }, + { + "data": { + "text/html": [ + "Query job bb26cde9-1991-4e0a-8492-b19d15b1b7aa is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 7ddd0883-492d-46bc-a588-f3cbab2474bb is DONE. 16.5 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 5de571e4-d2f9-43c7-b014-3d65a3731b64 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job d20ac7d8-cd21-4a1f-a200-2dfa6373bcdb is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 235e8a80-33ea-4a95-a7d0-34e40a8ca396 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + "0 303.855563 141869.030392 0.008989 \n", + "\n", + " median_absolute_error r2_score explained_variance \n", + "0 245.102301 0.731737 0.732793 \n", + "\n", + "[1 rows x 6 columns]\n" + ] + } + ], + "source": [ + "for X_train, X_test, y_train, y_test in kf.split(X, y):\n", + " model = linear_model.LinearRegression()\n", + " model.fit(X_train, y_train)\n", + " score = model.score(X_test, y_test)\n", + "\n", + " print(score)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2 Use cross_validate Function to Do Cross Validation (Automatic Approach)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Query job 9274ae2e-e9a7-4701-ac64-56632323d02a is DONE. 0 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 22f9477b-de02-4c07-b480-c3270a69d7e0 is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ebb192b7-4a9e-4238-b4e6-b630e2f94988 is DONE. 16.5 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 44441e8c-8753-41b0-b1b7-9a6c4eab8c74 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 239fed9a-b488-47da-a0df-a3b7c6ec40f4 is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job f4248b2d-3430-426c-872d-8590f2878366 is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job d9f6b034-c300-4dd7-91dd-48fa912f2456 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job e2f39f5b-2f4c-402a-a8d5-a7cff918508d is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 54cf3710-b5f4-4aec-b11f-0281126a151a is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 833d13cd-ec59-499b-98f6-95ec18766698 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 0120e332-0691-44a4-9198-f5c131b8f59c is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job f4ba7a4c-5fd9-4f97-ab34-a8f139e7472a is DONE. 16.4 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 857aadfc-2ade-429c-bef8-428e44d48c55 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 906d6d34-a506-4957-b07f-7e5ed2e0634b is DONE. 25.9 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 498563db-3e68-4df7-a2d5-83da6adb49ed is DONE. 16.5 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 01af95ca-6288-4253-b379-7327e1c9de88 is DONE. 26.3 kB processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 5ce36d32-6db1-42e5-a8cf-84bb8244a57e is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job e05ec77d-6025-4edd-b5e3-9c4e7a124e71 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 418a4a5d-2bb3-41e5-9e7c-9852389a491b is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job b33e30da-cfed-4d6f-b227-f433d97879cb is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job 7ad7f0c8-ecae-4ef2-bc91-0ebeb5f88e7b is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job a6e8bd12-1122-4c26-b0e1-58342238016c is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job c553439c-9586-479c-92c5-01a0d333125b is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job c598d64c-26b9-49fc-afad-a6544b38cfa2 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job ebcb73e8-1294-4f10-b826-c495046fd714 is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Query job d73f57ba-a25d-4b90-b474-13d81a3e22ab is DONE. 48 Bytes processed. Open Job" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'test_score': [ mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + " 0 237.154735 97636.17064 0.005571 \n", + " \n", + " median_absolute_error r2_score explained_variance \n", + " 0 187.883888 0.842018 0.846816 \n", + " \n", + " [1 rows x 6 columns],\n", + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + " 0 304.281635 141966.045867 0.008064 \n", + " \n", + " median_absolute_error r2_score explained_variance \n", + " 0 236.096453 0.762979 0.764008 \n", + " \n", + " [1 rows x 6 columns],\n", + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + " 0 316.380322 157332.146085 0.009699 \n", + " \n", + " median_absolute_error r2_score explained_variance \n", + " 0 222.824496 0.764607 0.765369 \n", + " \n", + " [1 rows x 6 columns],\n", + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + " 0 309.609657 152421.826588 0.009772 \n", + " \n", + " median_absolute_error r2_score explained_variance \n", + " 0 254.163976 0.772954 0.773119 \n", + " \n", + " [1 rows x 6 columns],\n", + " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", + " 0 339.339345 169760.629993 0.010597 \n", + " \n", + " median_absolute_error r2_score explained_variance \n", + " 0 312.335706 0.741167 0.74118 \n", + " \n", + " [1 rows x 6 columns]],\n", + " 'fit_time': [18.200648623984307,\n", + " 17.565149880945683,\n", + " 18.202434757025912,\n", + " 18.04062689607963,\n", + " 19.370970834977925],\n", + " 'score_time': [4.76077218609862,\n", + " 4.577479084953666,\n", + " 4.581933492794633,\n", + " 4.741644307971001,\n", + " 5.1031754210125655]}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# By using model_selection.cross_validate, the above 2.1 process is automated. The returned scores contains the evaluation results for each fold.\n", + "model = linear_model.LinearRegression()\n", + "scores = model_selection.cross_validate(model, X, y, cv=5)\n", + "scores" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/regression/bq_dataframes_ml_linear_regression.ipynb b/notebooks/ml/bq_dataframes_ml_linear_regression.ipynb similarity index 100% rename from notebooks/regression/bq_dataframes_ml_linear_regression.ipynb rename to notebooks/ml/bq_dataframes_ml_linear_regression.ipynb diff --git a/notebooks/regression/easy_linear_regression.ipynb b/notebooks/ml/easy_linear_regression.ipynb similarity index 100% rename from notebooks/regression/easy_linear_regression.ipynb rename to notebooks/ml/easy_linear_regression.ipynb diff --git a/notebooks/regression/sklearn_linear_regression.ipynb b/notebooks/ml/sklearn_linear_regression.ipynb similarity index 100% rename from notebooks/regression/sklearn_linear_regression.ipynb rename to notebooks/ml/sklearn_linear_regression.ipynb diff --git a/noxfile.py b/noxfile.py index c704da00a5..714c8333bd 100644 --- a/noxfile.py +++ b/noxfile.py @@ -731,7 +731,7 @@ def notebook(session: nox.Session): # appropriate values and omitting cleanup logic that may break # our test infrastructure. "notebooks/getting_started/ml_fundamentals_bq_dataframes.ipynb", # Needs DATASET. - "notebooks/regression/bq_dataframes_ml_linear_regression.ipynb", # Needs DATASET_ID. + "notebooks/ml/bq_dataframes_ml_linear_regression.ipynb", # Needs DATASET_ID. "notebooks/generative_ai/bq_dataframes_ml_drug_name_generation.ipynb", # Needs CONNECTION. # TODO(b/332737009): investigate why we get 404 errors, even though # bq_dataframes_llm_code_generation creates a bucket in the sample.