diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 6079e0ea22..9dc1652912 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -199,6 +199,8 @@ def _fit( Returns: ARIMAPlus: Fitted estimator. """ + X, y = utils.batch_convert_to_dataframe(X, y) + if X.columns.size != 1: raise ValueError( "Time series timestamp input X must only contain 1 column." @@ -206,8 +208,6 @@ def _fit( if y.columns.size != 1: raise ValueError("Time series data input y must only contain 1 column.") - X, y = utils.batch_convert_to_dataframe(X, y) - self._bqml_model = self._bqml_model_factory.create_time_series_model( X, y, diff --git a/tests/system/large/ml/test_forecasting.py b/tests/system/large/ml/test_forecasting.py index 438177b1a0..bb53305b94 100644 --- a/tests/system/large/ml/test_forecasting.py +++ b/tests/system/large/ml/test_forecasting.py @@ -36,7 +36,7 @@ @pytest.fixture(scope="module") def arima_model(time_series_df_default_index): model = forecasting.ARIMAPlus() - X_train = time_series_df_default_index[["parsed_date"]] + X_train = time_series_df_default_index["parsed_date"] y_train = time_series_df_default_index[["total_visits"]] model.fit(X_train, y_train) return model @@ -114,7 +114,7 @@ def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id): ) X_train = time_series_df_default_index[["parsed_date"]] - y_train = time_series_df_default_index[["total_visits"]] + y_train = time_series_df_default_index["total_visits"] model.fit(X_train, y_train) # save, load to ensure configuration was kept