Skip to content

Commit

Permalink
fix: arima model series input.
Browse files Browse the repository at this point in the history
  • Loading branch information
Genesis929 committed Dec 20, 2024
1 parent d87ab97 commit 9539f5d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,15 @@ def _fit(
Returns:
ARIMAPlus: Fitted estimator.
"""
X, y = utils.batch_convert_to_dataframe(X, y)

if X.columns.size != 1:
raise ValueError(
"Time series timestamp input X must only contain 1 column."
)
if y.columns.size != 1:
raise ValueError("Time series data input y must only contain 1 column.")

X, y = utils.batch_convert_to_dataframe(X, y)

self._bqml_model = self._bqml_model_factory.create_time_series_model(
X,
y,
Expand Down
4 changes: 2 additions & 2 deletions tests/system/large/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@pytest.fixture(scope="module")
def arima_model(time_series_df_default_index):
model = forecasting.ARIMAPlus()
X_train = time_series_df_default_index[["parsed_date"]]
X_train = time_series_df_default_index["parsed_date"]
y_train = time_series_df_default_index[["total_visits"]]
model.fit(X_train, y_train)
return model
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id):
)

X_train = time_series_df_default_index[["parsed_date"]]
y_train = time_series_df_default_index[["total_visits"]]
y_train = time_series_df_default_index["total_visits"]
model.fit(X_train, y_train)

# save, load to ensure configuration was kept
Expand Down

0 comments on commit 9539f5d

Please sign in to comment.