Skip to content

Commit

Permalink
feat: add seq2seq forecasting training job (#1196)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)

Fixes b/229909845 🦕

---

Adds a `SequenceToSequencePlusForecastingTrainingJob` to training jobs. This job has the exact same signature as `AutoMLForecastingTrainingJob`, but we are creating a separate job in case the two models diverge in the future.

The logic for `AutoMLForecastingTrainingJob` has been moved to a new abstract base class `_ForecastingTrainingJob`. The only things that differ between the seq2seq and automl training jobs that extend it are the `model_type` and `training_task_definition`.
  • Loading branch information
TheMichaelHu authored Jun 3, 2022
1 parent efaf6ed commit 643d335
Show file tree
Hide file tree
Showing 5 changed files with 1,890 additions and 1,619 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
CustomPythonPackageTrainingJob,
AutoMLTabularTrainingJob,
AutoMLForecastingTrainingJob,
SequenceToSequencePlusForecastingTrainingJob,
AutoMLImageTrainingJob,
AutoMLTextTrainingJob,
AutoMLVideoTrainingJob,
Expand Down Expand Up @@ -116,6 +117,7 @@
"Model",
"ModelEvaluation",
"PipelineJob",
"SequenceToSequencePlusForecastingTrainingJob",
"TabularDataset",
"Tensorboard",
"TensorboardExperiment",
Expand Down
1 change: 1 addition & 0 deletions google/cloud/aiplatform/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class definition:
custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"
automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml"
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"
Expand Down
Loading

0 comments on commit 643d335

Please sign in to comment.