diff --git a/components/google-cloud/RELEASE.md b/components/google-cloud/RELEASE.md
index 38b84d88bf0..fc40e16bf96 100644
--- a/components/google-cloud/RELEASE.md
+++ b/components/google-cloud/RELEASE.md
@@ -6,6 +6,7 @@
 * Add sliced evaluation metrics support for custom and unstructured AutoML models in evaluation pipeline and evaluation pipeline with feature attribution.
 * Support `service_account` in `ModelBatchPredictOp`.
 * Release `DataflowFlexTemplateJobOp` to GA namespace (`v1.dataflow.DataflowFlexTemplateJobOp`).
+* Make `model_checkpoint` optional for `preview.llm.infer_pipeline`. If not provided, the base model associated with the `large_model_reference` will be used.
 
 ## Release 2.4.1
 * Disable caching for LLM pipeline tasks that store temporary artifacts.
diff --git a/components/google-cloud/google_cloud_pipeline_components/preview/llm/infer/component.py b/components/google-cloud/google_cloud_pipeline_components/preview/llm/infer/component.py
index cfa0f715568..429916e4346 100644
--- a/components/google-cloud/google_cloud_pipeline_components/preview/llm/infer/component.py
+++ b/components/google-cloud/google_cloud_pipeline_components/preview/llm/infer/component.py
@@ -33,8 +33,8 @@
 )
 def infer_pipeline(
     large_model_reference: str,
-    model_checkpoint: str,
     prompt_dataset: str,
+    model_checkpoint: Optional[str] = None,
     prompt_sequence_length: int = 512,
     target_sequence_length: int = 64,
     sampling_strategy: str = 'greedy',
@@ -47,7 +47,7 @@ def infer_pipeline(
 
   Args:
     large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
-    model_checkpoint: Cloud storage path to the model checkpoint.
+    model_checkpoint: Optional Cloud storage path to the model checkpoint. If not provided, the default checkpoint for the `large_model_reference` will be used.
     prompt_dataset: Cloud storage path to an unlabled prompt dataset used for reinforcement learning. The dataset format is jsonl. Each example in the dataset must have an `input_text` field that contains the prompt.
     prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
     target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
@@ -66,7 +66,8 @@ def infer_pipeline(
       use_test_spec=env.get_use_test_machine_spec(),
   ).set_display_name('Resolve Machine Spec')
   reference_model_metadata = function_based.resolve_reference_model_metadata(
-      large_model_reference=large_model_reference
+      large_model_reference=large_model_reference,
+      reference_model_path=model_checkpoint,
   ).set_display_name('Resolve Model Metadata')
 
   prompt_dataset_image_uri = function_based.resolve_private_image_uri(
@@ -98,7 +99,7 @@ def infer_pipeline(
   bulk_inference = bulk_inferrer.BulkInferrer(
       project=project,
       location=location,
-      input_model=model_checkpoint,
+      input_model=reference_model_metadata.outputs['reference_model_path'],
       input_dataset_path=prompt_dataset_importer.outputs['imported_data_path'],
       dataset_split=env.TRAIN_SPLIT,
       inputs_sequence_length=prompt_sequence_length,
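
For reviewers, a minimal sketch of how the change is exercised from the client side. The bucket paths, project, and region below are placeholders and not part of this PR; the compile-and-submit flow uses the standard kfp and google-cloud-aiplatform APIs.

# Sketch: compiling and submitting the pipeline with `model_checkpoint`
# omitted. All gs:// paths, the project, and the region are placeholders.
from google.cloud import aiplatform
from google_cloud_pipeline_components.preview.llm import infer_pipeline
from kfp import compiler

# Compile the pipeline function to a local package.
compiler.Compiler().compile(
    pipeline_func=infer_pipeline,
    package_path='infer_pipeline.yaml',
)

# After this change, `model_checkpoint` may be left out of parameter_values;
# resolve_reference_model_metadata then falls back to the base checkpoint
# registered for `large_model_reference`.
job = aiplatform.PipelineJob(
    display_name='infer-without-checkpoint',
    template_path='infer_pipeline.yaml',
    parameter_values={
        'large_model_reference': 't5-small',
        'prompt_dataset': 'gs://example-bucket/prompts.jsonl',  # placeholder
        # 'model_checkpoint': 'gs://example-bucket/ckpt',  # optional override
    },
    project='example-project',  # placeholder
    location='us-central1',
)
job.submit()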