Skip to content

Commit

Permalink
feat(components): Make model_checkpoint optional for `preview.llm.i…
Browse files Browse the repository at this point in the history
…nfer_pipeline`

PiperOrigin-RevId: 574876480
  • Loading branch information
Googler committed Oct 19, 2023
1 parent d8a0660 commit e8fb699
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
1 change: 1 addition & 0 deletions components/google-cloud/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* Add sliced evaluation metrics support for custom and unstructured AutoML models in evaluation pipeline and evaluation pipeline with feature attribution.
* Support `service_account` in `ModelBatchPredictOp`.
* Release `DataflowFlexTemplateJobOp` to GA namespace (`v1.dataflow.DataflowFlexTemplateJobOp`).
* Make `model_checkpoint` optional for `preview.llm.infer_pipeline`. If not provided, the base model associated with the `large_model_reference` will be used.

## Release 2.4.1
* Disable caching for LLM pipeline tasks that store temporary artifacts.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
)
def infer_pipeline(
large_model_reference: str,
model_checkpoint: str,
prompt_dataset: str,
model_checkpoint: Optional[str] = None,
prompt_sequence_length: int = 512,
target_sequence_length: int = 64,
sampling_strategy: str = 'greedy',
Expand All @@ -47,7 +47,7 @@ def infer_pipeline(
Args:
large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
model_checkpoint: Cloud storage path to the model checkpoint.
model_checkpoint: Optional Cloud storage path to the model checkpoint. If not provided, the default checkpoint for the `large_model_reference` will be used.
prompt_dataset: Cloud storage path to an unlabled prompt dataset used for reinforcement learning. The dataset format is jsonl. Each example in the dataset must have an `input_text` field that contains the prompt.
prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
Expand All @@ -66,7 +66,8 @@ def infer_pipeline(
use_test_spec=env.get_use_test_machine_spec(),
).set_display_name('Resolve Machine Spec')
reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference
large_model_reference=large_model_reference,
reference_model_path=model_checkpoint,
).set_display_name('Resolve Model Metadata')

prompt_dataset_image_uri = function_based.resolve_private_image_uri(
Expand Down Expand Up @@ -98,7 +99,7 @@ def infer_pipeline(
bulk_inference = bulk_inferrer.BulkInferrer(
project=project,
location=location,
input_model=model_checkpoint,
input_model=reference_model_metadata.outputs['reference_model_path'],
input_dataset_path=prompt_dataset_importer.outputs['imported_data_path'],
dataset_split=env.TRAIN_SPLIT,
inputs_sequence_length=prompt_sequence_length,
Expand Down

0 comments on commit e8fb699

Please sign in to comment.