Skip to content

Commit

Permalink
feat: LVM - Removed the width and height parameters from `ImageGe…
Browse files Browse the repository at this point in the history
…nerationModel.generate_images` since the service has dropped support for image sizes and aspect ratios

PiperOrigin-RevId: 558246815
  • Loading branch information
Ark-kun authored and copybara-github committed Aug 18, 2023
1 parent ce60cf7 commit 52897e6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 27 deletions.
25 changes: 15 additions & 10 deletions tests/system/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ def test_image_generation_model_generate_images(self):
"imagegeneration@001"
)

width = 1024
height = 768
# TODO(b/295946075): The service stopped supporting image sizes.
# width = 1024
# height = 768
number_of_images = 4
seed = 1
guidance_scale = 15
Expand All @@ -104,20 +105,23 @@ def test_image_generation_model_generate_images(self):
# Optional:
negative_prompt=negative_prompt1,
number_of_images=number_of_images,
width=width,
height=height,
# TODO(b/295946075): The service stopped supporting image sizes.
# width=width,
# height=height,
seed=seed,
guidance_scale=guidance_scale,
)

assert len(image_response.images) == number_of_images
for idx, image in enumerate(image_response):
assert image._pil_image.size == (width, height)
# TODO(b/295946075): The service stopped supporting image sizes.
# assert image._pil_image.size == (width, height)
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt1
assert image.generation_parameters["negative_prompt"] == negative_prompt1
assert image.generation_parameters["width"] == width
assert image.generation_parameters["height"] == height
# TODO(b/295946075): The service stopped supporting image sizes.
# assert image.generation_parameters["width"] == width
# assert image.generation_parameters["height"] == height
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
Expand All @@ -127,13 +131,13 @@ def test_image_generation_model_generate_images(self):
image_path = os.path.join(temp_dir, "image.png")
image_response[0].save(location=image_path)
image1 = vision_models.GeneratedImage.load_from_file(image_path)
assert image1._pil_image.size == (width, height)
# assert image1._pil_image.size == (width, height)
assert image1.generation_parameters
assert image1.generation_parameters["prompt"] == prompt1

# Preparing mask
mask_path = os.path.join(temp_dir, "mask.png")
mask_pil_image = PIL_Image.new(mode="RGB", size=(width, height))
mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
mask_pil_image.save(mask_path, format="PNG")
mask_image = vision_models.Image.load_from_file(mask_path)

Expand All @@ -150,7 +154,8 @@ def test_image_generation_model_generate_images(self):
)
assert len(image_response2.images) == number_of_images
for idx, image in enumerate(image_response2):
assert image._pil_image.size == (width, height)
# TODO(b/295946075): The service stopped supporting image sizes.
# assert image._pil_image.size == (width, height)
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt2
assert image.generation_parameters["seed"] == seed
Expand Down
25 changes: 16 additions & 9 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import tempfile
from typing import Any, Dict
import unittest
from unittest import mock

from google.cloud import aiplatform
Expand Down Expand Up @@ -175,7 +176,9 @@ def test_generate_images(self):
model = self._get_image_generation_model()

width = 1024
height = 768
# TODO(b/295946075) The service stopped supporting image sizes.
# height = 768
height = 1024
number_of_images = 4
seed = 1
guidance_scale = 15
Expand All @@ -200,8 +203,9 @@ def test_generate_images(self):
# Optional:
negative_prompt=negative_prompt1,
number_of_images=number_of_images,
width=width,
height=height,
# TODO(b/295946075) The service stopped supporting image sizes.
# width=width,
# height=height,
seed=seed,
guidance_scale=guidance_scale,
)
Expand All @@ -210,8 +214,9 @@ def test_generate_images(self):
actual_instance = predict_kwargs["instances"][0]
assert actual_instance["prompt"] == prompt1
assert actual_instance["negativePrompt"] == negative_prompt1
assert actual_parameters["sampleImageSize"] == str(max(width, height))
assert actual_parameters["aspectRatio"] == f"{width}:{height}"
# TODO(b/295946075) The service stopped supporting image sizes.
# assert actual_parameters["sampleImageSize"] == str(max(width, height))
# assert actual_parameters["aspectRatio"] == f"{width}:{height}"
assert actual_parameters["seed"] == seed
assert actual_parameters["guidanceScale"] == guidance_scale

Expand All @@ -221,8 +226,9 @@ def test_generate_images(self):
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt1
assert image.generation_parameters["negative_prompt"] == negative_prompt1
assert image.generation_parameters["width"] == width
assert image.generation_parameters["height"] == height
# TODO(b/295946075) The service stopped supporting image sizes.
# assert image.generation_parameters["width"] == width
# assert image.generation_parameters["height"] == height
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
Expand All @@ -233,13 +239,13 @@ def test_generate_images(self):
image_path = os.path.join(temp_dir, "image.png")
image_response[0].save(location=image_path)
image1 = vision_models.GeneratedImage.load_from_file(image_path)
assert image1._pil_image.size == (width, height)
# assert image1._pil_image.size == (width, height)
assert image1.generation_parameters
assert image1.generation_parameters["prompt"] == prompt1

# Preparing mask
mask_path = os.path.join(temp_dir, "mask.png")
mask_pil_image = PIL_Image.new(mode="RGB", size=(width, height))
mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
mask_pil_image.save(mask_path, format="PNG")
mask_image = vision_models.Image.load_from_file(mask_path)

Expand Down Expand Up @@ -273,6 +279,7 @@ def test_generate_images(self):
assert image.generation_parameters["base_image_hash"]
assert image.generation_parameters["mask_hash"]

@unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
def test_generate_images_requests_square_images_by_default(self):
"""Tests that the model class generates square image by default."""
model = self._get_image_generation_model()
Expand Down
14 changes: 6 additions & 8 deletions vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def _generate_images(
instance = {"prompt": prompt}
shared_generation_parameters = {
"prompt": prompt,
"width": width,
"height": height,
# b/295946075 The service stopped supporting image sizes.
# "width": width,
# "height": height,
"number_of_images_in_batch": number_of_images,
}

Expand Down Expand Up @@ -238,8 +239,6 @@ def generate_images(
*,
negative_prompt: Optional[str] = None,
number_of_images: int = 1,
width: Optional[int] = None,
height: Optional[int] = None,
guidance_scale: Optional[float] = None,
seed: Optional[int] = None,
) -> "ImageGenerationResponse":
Expand All @@ -250,8 +249,6 @@ def generate_images(
negative_prompt: A description of what you want to omit in
the generated images.
number_of_images: Number of images to generate. Range: 1..8.
width: Width of the image. One of the sizes must be 256 or 1024.
height: Height of the image. One of the sizes must be 256 or 1024.
guidance_scale: Controls the strength of the prompt.
Suggested values are:
* 0-9 (low strength)
Expand All @@ -266,8 +263,9 @@ def generate_images(
prompt=prompt,
negative_prompt=negative_prompt,
number_of_images=number_of_images,
width=width,
height=height,
# b/295946075 The service stopped supporting image sizes.
width=None,
height=None,
guidance_scale=guidance_scale,
seed=seed,
)
Expand Down

0 comments on commit 52897e6

Please sign in to comment.