
fix: Add Batch Prediction to Gen AI SDK Intro #1595

Merged
merged 3 commits on Jan 3, 2025
213 changes: 204 additions & 9 deletions gemini/getting-started/intro_genai_sdk.ipynb
Expand Up @@ -29,7 +29,7 @@
"id": "JAPoU8Sm5E6e"
},
"source": [
"# Getting started with Google generative AI using the Gen AI SDK\n",
"# Getting started with Google Generative AI using the Gen AI SDK\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
Expand Down Expand Up @@ -117,6 +117,7 @@
"- Count tokens and compute tokens\n",
"- Use context caching\n",
"- Function calling\n",
"- Batch prediction\n",
"- Get text embeddings\n"
]
},
Expand Down Expand Up @@ -146,7 +147,7 @@
},
"outputs": [],
"source": [
"%pip install --upgrade --user --quiet google-genai"
"%pip install --upgrade --user --quiet google-genai pandas"
]
},
{
Expand Down Expand Up @@ -187,14 +188,17 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {
"id": "qgdSpVmDbdQ9"
},
"outputs": [],
"source": [
"import datetime\n",
"\n",
"from google import genai\n",
"from google.genai.types import (\n",
" CreateBatchJobConfig,\n",
" CreateCachedContentConfig,\n",
" EmbedContentConfig,\n",
" FunctionDeclaration,\n",
Expand All @@ -211,9 +215,9 @@
"id": "Ve4YBlDqzyj9"
},
"source": [
"## Connect to a generative AI API service\n",
"## Connect to a Generative AI API service\n",
"\n",
"Google's Gen AI APIs and models including Gemini are available in the following two API services:\n",
"Google Gen AI APIs and models including Gemini are available in the following two API services:\n",
"\n",
"- **[Google AI for Developers](https://ai.google.dev/gemini-api/docs)**: Experiment, prototype, and deploy small projects.\n",
"- **[Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/overview)**: Build enterprise-ready projects on Google Cloud.\n",
Expand Down Expand Up @@ -243,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {
"id": "Nqwi-5ufWp_B"
},
Expand All @@ -260,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {
"id": "T-tiytzQE0uM"
},
Expand All @@ -282,7 +286,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {
"id": "-coEslfWPrxo"
},
Expand Down Expand Up @@ -1007,6 +1011,197 @@
"client.caches.delete(name=cached_content.name)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "43be33d2672b"
},
"source": [
"## Batch prediction\n",
"\n",
"Unlike online (synchronous) responses, where you are limited to one input request at a time, [batch predictions for the Gemini API in Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini) let you send a large number of requests to Gemini in a single batch request. The model's responses are then written asynchronously to your output location in [Cloud Storage](https://cloud.google.com/storage/docs/introduction) or [BigQuery](https://cloud.google.com/bigquery/docs/storage_overview).\n",
"\n",
"Batch predictions are generally more efficient and cost-effective than online predictions when processing a large number of inputs that are not latency sensitive."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "adf948ae326b"
},
"source": [
"### Prepare batch inputs\n",
"\n",
"The input for batch requests specifies the items to send to your model for prediction.\n",
"\n",
"Batch requests for Gemini accept BigQuery storage sources and Cloud Storage sources. You can learn more about the batch input formats in the [Batch text generation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini#prepare_your_inputs) page.\n",
"\n",
"This tutorial uses Cloud Storage as an example. The requirements for Cloud Storage input are:\n",
"\n",
"- File format: [JSON Lines (JSONL)](https://jsonlines.org/)\n",
"- Located in `us-central1`\n",
"- Appropriate read permissions for the service account\n",
"\n",
"Each request that you send to a model can include parameters that control how the model generates a response. Learn more about Gemini parameters in the [Experiment with parameter values](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/adjust-parameter-values) page.\n",
"\n",
"This is one of the example requests in the input JSONL file `batch_requests_for_multimodal_input_2.jsonl`:\n",
"\n",
"```json\n",
"{\"request\":{\"contents\": [{\"role\": \"user\", \"parts\": [{\"text\": \"List objects in this image.\"}, {\"file_data\": {\"file_uri\": \"gs://cloud-samples-data/generative-ai/image/office-desk.jpeg\", \"mime_type\": \"image/jpeg\"}}]}],\"generationConfig\":{\"temperature\": 0.4}}}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "81b25154a51a"
},
"outputs": [],
"source": [
"INPUT_DATA = \"gs://cloud-samples-data/generative-ai/batch/batch_requests_for_multimodal_input_2.jsonl\" # @param {type:\"string\"}"
]
},
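Each line of the input file is a standalone JSON object with a `request` field, as in the sample record above. As a minimal sketch (the prompts here are hypothetical, not taken from the sample dataset), such a JSONL payload can be built with the standard `json` module:

```python
import json

# Hypothetical prompts to batch; each one becomes a single JSONL line.
prompts = [
    "List objects in this image.",
    "Describe the scene in this image.",
]

requests = [
    {
        "request": {
            "contents": [{"role": "user", "parts": [{"text": p}]}],
            "generationConfig": {"temperature": 0.4},
        }
    }
    for p in prompts
]

# JSON Lines: one compact JSON object per line, newline-separated.
jsonl = "\n".join(json.dumps(r) for r in requests)
print(jsonl.splitlines()[0])
```

The resulting text would then be uploaded to a Cloud Storage object (for example with the `google-cloud-storage` client) before being referenced as the batch input.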
{
"cell_type": "markdown",
"metadata": {
"id": "2031bb3f44c2"
},
"source": [
"### Prepare batch output location\n",
"\n",
"When a batch prediction task completes, the output is stored in the location that you specified in your request.\n",
"\n",
"- The location is in the form of a Cloud Storage or BigQuery URI prefix, for example:\n",
"`gs://path/to/output/data` or `bq://projectId.bqDatasetId`.\n",
"\n",
"- If not specified, `gs://STAGING_BUCKET/gen-ai-batch-prediction` is used for a Cloud Storage source and `bq://PROJECT_ID.gen_ai_batch_prediction.predictions_TIMESTAMP` is used for a BigQuery source.\n",
"\n",
"This tutorial uses a Cloud Storage bucket as an example for the output location.\n",
"\n",
"- You can specify the URI of your Cloud Storage bucket in `BUCKET_URI`, or\n",
"- if it is not specified, a new Cloud Storage bucket in the form of `gs://PROJECT_ID-TIMESTAMP` will be created for you."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "fddd98cd84cd"
},
"outputs": [],
"source": [
"BUCKET_URI = \"[your-cloud-storage-bucket]\" # @param {type:\"string\"}\n",
"\n",
"if BUCKET_URI == \"[your-cloud-storage-bucket]\":\n",
" TIMESTAMP = datetime.datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
" BUCKET_URI = f\"gs://{PROJECT_ID}-{TIMESTAMP}\"\n",
"\n",
" ! gsutil mb -l {LOCATION} -p {PROJECT_ID} {BUCKET_URI}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d7da62c98880"
},
"source": [
"### Send a batch prediction request\n",
"\n",
"To make a batch prediction request, you specify a source model ID, an input source, and an output location where Vertex AI stores the batch prediction results.\n",
"\n",
"To learn more, see the [Batch prediction API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api) page.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ed3c2925663"
},
"outputs": [],
"source": [
"batch_job = client.batches.create(\n",
" model=\"gemini-1.5-flash-002\",\n",
" src=INPUT_DATA,\n",
" config=CreateBatchJobConfig(dest=BUCKET_URI),\n",
")\n",
"batch_job.name"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f1bd49ff2c9e"
},
"source": [
"Print out the job status and other properties. You can also check the status in the Cloud Console at https://console.cloud.google.com/vertex-ai/batch-predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ee2ec586e4f1"
},
"outputs": [],
"source": [
"client.batches.get(name=batch_job.name)"
]
},
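Because batch jobs run asynchronously, a common pattern is to poll the job state until it reaches a terminal value. The sketch below assumes the SDK reports states as strings such as `JOB_STATE_SUCCEEDED` (per the Vertex AI job lifecycle); the poller takes a state-fetching callable so it can wrap a call like `lambda: client.batches.get(name=batch_job.name).state`:

```python
import time

# Terminal states, assumed from the Vertex AI batch prediction job lifecycle.
TERMINAL_STATES = {"JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"}

def wait_for_job(get_state, interval_s=30.0, max_polls=120):
    """Poll get_state() until it returns a terminal state, then return it."""
    for _ in range(max_polls):
        state = get_state()
        if state in TERMINAL_STATES:
            return state
        time.sleep(interval_s)
    raise TimeoutError("batch job did not finish within the polling budget")
```

With the job created above, this could be invoked as `wait_for_job(lambda: client.batches.get(name=batch_job.name).state)`.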
{
"cell_type": "markdown",
"metadata": {
"id": "64eaf082ecb0"
},
"source": [
"Optionally, you can list all the batch prediction jobs in the project."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "da8e9d43a89b"
},
"outputs": [],
"source": [
"for job in client.batches.list():\n",
" print(job.name, job.create_time, job.state)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0156eaf66675"
},
"source": [
"### Retrieve batch prediction results\n",
"\n",
"When a batch prediction task is complete, the output of the prediction is stored in the location that you specified in your request. It is also available in `batch_job.dest.bigquery_uri` or `batch_job.dest.gcs_uri`.\n",
"\n",
"Example output:\n",
"\n",
"```json\n",
"{\"status\": \"\", \"processed_time\": \"2024-11-13T14:04:28.376+00:00\", \"request\": {\"contents\": [{\"parts\": [{\"file_data\": null, \"text\": \"List objects in this image.\"}, {\"file_data\": {\"file_uri\": \"gs://cloud-samples-data/generative-ai/image/gardening-tools.jpeg\", \"mime_type\": \"image/jpeg\"}, \"text\": null}], \"role\": \"user\"}], \"generationConfig\": {\"temperature\": 0.4}}, \"response\": {\"candidates\": [{\"avgLogprobs\": -0.10394711927934126, \"content\": {\"parts\": [{\"text\": \"Here's a list of the objects in the image:\\n\\n* **Watering can:** A green plastic watering can with a white rose head.\\n* **Plant:** A small plant (possibly oregano) in a terracotta pot.\\n* **Terracotta pots:** Two terracotta pots, one containing the plant and another empty, stacked on top of each other.\\n* **Gardening gloves:** A pair of striped gardening gloves.\\n* **Gardening tools:** A small trowel and a hand cultivator (hoe). Both are green with black handles.\"}], \"role\": \"model\"}, \"finishReason\": \"STOP\"}], \"modelVersion\": \"gemini-1.5-flash-002@default\", \"usageMetadata\": {\"candidatesTokenCount\": 110, \"promptTokenCount\": 264, \"totalTokenCount\": 374}}}\n",
"```"
]
},
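Each output line can also be parsed directly with the standard `json` module. The sketch below extracts the generated text and token usage from a record shaped like the example above (the record here is a trimmed, hypothetical stand-in, not actual model output):

```python
import json

# A trimmed, hypothetical output record in the shape shown above.
line = json.dumps({
    "response": {
        "candidates": [
            {
                "content": {"parts": [{"text": "A watering can and two pots."}], "role": "model"},
                "finishReason": "STOP",
            }
        ],
        "usageMetadata": {"totalTokenCount": 374},
    }
})

record = json.loads(line)
candidate = record["response"]["candidates"][0]
# Concatenate the text parts of the first candidate.
text = "".join(part["text"] for part in candidate["content"]["parts"])
total_tokens = record["response"]["usageMetadata"]["totalTokenCount"]
print(text, total_tokens)
```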
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c2ce0968112c"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Load the JSONL file into a DataFrame\n",
"df = pd.read_json(f\"{batch_job.dest.gcs_uri}/*/predictions.jsonl\", lines=True)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand All @@ -1015,7 +1210,7 @@
"source": [
"## Get text embeddings\n",
"\n",
"You can get text embeddings for a snippet of text by using `embed_content` method. All models produce an output with 768 dimensions by default. However, some models give users the option to choose an output dimensionality between 1 and 768. See [Vertex AI text embeddings API](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings) for more details."
"You can get text embeddings for a snippet of text by using the `embed_content` method. All models produce an output with 768 dimensions by default. However, some models give users the option to choose an output dimensionality between `1` and `768`. See [Vertex AI text embeddings API](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings) for more details."
]
},
{