feat: Vision Models - onboard Image Segmentation.
Generates masks by segmenting a base image. Supports several segmentation modes and input modalities, along with parameters to customize the prediction response.

More information is available in the model card at https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/image-segmentation-001
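
For context, a minimal usage sketch of the API added in this commit (the project, location, and file names below are placeholders, not part of the change):

import vertexai
from vertexai.preview.vision_models import Image, ImageSegmentationModel

vertexai.init(project="my-project", location="us-central1")  # placeholder values

model = ImageSegmentationModel.from_pretrained("image-segmentation-001")
base_image = Image.load_from_file("cat.png")  # any local image file

# Foreground mode is the default; other modes are documented on segment_image() below.
response = model.segment_image(base_image=base_image)
for mask in response:
    print(mask.labels)  # entity labels, if the mode produces them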

PiperOrigin-RevId: 686264882
vertex-sdk-bot authored and copybara-github committed Oct 15, 2024
1 parent 507e988 commit ae63a43
Showing 3 changed files with 291 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
@@ -101,6 +101,19 @@
},
}

_IMAGE_SEGMENTATION_PUBLISHER_MODEL_DICT = {
"name": "publishers/google/models/image-segmentation-001",
"version_id": "default",
"open_source_category": "PROPRIETARY",
"launch_stage": (gca_publisher_model.PublisherModel.LaunchStage.PRIVATE_PREVIEW),
"publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/image-segmentation-001",
"predict_schemata": {
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/image_segmentation_model_1.0.0.yaml",
"parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/image_segmentation_model_1.0.0.yaml",
"prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/image_segmentation_model_1.0.0.yaml",
},
}


def make_image_base64(width: int, height: int) -> str:
image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(width, height))
@@ -173,6 +186,20 @@ def make_image_upscale_response_gcs() -> Dict[str, Any]:
return {"predictions": [predictions]}


def make_image_segmentation_response(
width: int, height: int, count: int = 1
) -> Dict[str, Any]:
predictions = []
for _ in range(count):
predictions.append(
{
"bytesBase64Encoded": make_image_base64(width, height),
"mimeType": "image/png",
}
)
return {"predictions": predictions}


def generate_image_from_file(
width: int = 100, height: int = 100
) -> ga_vision_models.Image:
@@ -1018,6 +1045,54 @@ def test_get_image_verification_results(self):
assert actual_results == [gca_prediction_response, "REJECT"]


@pytest.mark.usefixtures("google_auth_mock")
class TestImageSegmentationModel:
"""Unit tests for the image segmentation models."""

def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_get_image_segmentation_results(self):
"""Tests the image segmentation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_SEGMENTATION_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = ga_vision_models.ImageSegmentationModel.from_pretrained(
"image-segmentation-001"
)
mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/image-segmentation-001",
retry=base._DEFAULT_RETRY,
)

image = generate_image_from_file()
image_segmentation_response = make_image_segmentation_response(640, 640)
gca_prediction_response = gca_prediction_service.PredictResponse()
        gca_prediction_response.predictions.extend(
            image_segmentation_response["predictions"]
        )

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_prediction_response,
):
segmentation_response = model.segment_image(base_image=image)
            assert len(segmentation_response.masks) == 1


@pytest.mark.usefixtures("google_auth_mock")
class TestMultiModalEmbeddingModels:
"""Unit tests for the image generation models."""
10 changes: 10 additions & 0 deletions vertexai/preview/vision_models.py
@@ -15,15 +15,20 @@
"""Classes for working with vision models."""

from vertexai.vision_models._vision_models import (
EntityLabel,
GeneratedImage,
GeneratedMask,
Image,
ImageCaptioningModel,
ImageGenerationModel,
ImageGenerationResponse,
ImageQnAModel,
ImageSegmentationModel,
ImageSegmentationResponse,
ImageTextModel,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
Scribble,
Video,
VideoEmbedding,
VideoSegmentConfig,
@@ -32,16 +37,21 @@
)

__all__ = [
"EntityLabel",
"GeneratedMask",
"Image",
"ImageGenerationModel",
"ImageGenerationResponse",
"ImageCaptioningModel",
"ImageQnAModel",
"ImageSegmentationModel",
"ImageSegmentationResponse",
"ImageTextModel",
"WatermarkVerificationModel",
"GeneratedImage",
"MultiModalEmbeddingModel",
"MultiModalEmbeddingResponse",
"Scribble",
"Video",
"VideoEmbedding",
"VideoSegmentConfig",
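With these exports in place, the new classes can be imported from the preview namespace, for example:

from vertexai.preview.vision_models import (
    ImageSegmentationModel,
    ImageSegmentationResponse,
    Scribble,
)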
206 changes: 206 additions & 0 deletions vertexai/vision_models/_vision_models.py
@@ -1398,3 +1398,209 @@ def verify_image(self, image: Image) -> WatermarkVerificationResponse:
_prediction_response=response,
watermark_verification_result=verification_likelihood,
)


class Scribble:
"""Input scribble for image segmentation."""

__module__ = "vertexai.preview.vision_models"

    _image: Optional[Image] = None

def __init__(
self,
image_bytes: Optional[bytes],
gcs_uri: Optional[str] = None,
):
"""Creates a `Scribble` object.
Args:
image_bytes: Mask image file bytes.
gcs_uri: Mask image file Google Cloud Storage uri.
"""
if bool(image_bytes) == bool(gcs_uri):
raise ValueError("Either image_bytes or gcs_uri must be provided.")

        self._image = Image(image_bytes, gcs_uri)

@property
def image(self) -> Optional[Image]:
"""The scribble image."""
        return self._image


@dataclasses.dataclass
class EntityLabel:
"""Entity label holding a text label and any associated confidence score."""

__module__ = "vertexai.preview.vision_models"

label: Optional[str] = None
score: Optional[float] = None


class GeneratedMask(Image):
"""Generated image mask."""

__module__ = "vertexai.preview.vision_models"

    _labels: Optional[List[EntityLabel]] = None

def __init__(
self,
image_bytes: Optional[bytes],
gcs_uri: Optional[str] = None,
labels: Optional[List[EntityLabel]] = None,
):
"""Creates a `GeneratedMask` object.
Args:
image_bytes: Mask image file bytes.
gcs_uri: Mask image file Google Cloud Storage uri.
labels: Generated entity labels. Each text label might be associated
with a confidence score.
"""

super().__init__(
image_bytes=image_bytes,
gcs_uri=gcs_uri,
)
        self._labels = labels

@property
def labels(self) -> Optional[List[EntityLabel]]:
"""The entity labels of the masked object."""
        return self._labels


@dataclasses.dataclass
class ImageSegmentationResponse:
"""Image Segmentation response.
Attributes:
masks: The list of generated masks.
"""

__module__ = "vertexai.preview.vision_models"

_prediction_response: Any
masks: List[GeneratedMask]

def __iter__(self) -> typing.Iterator[GeneratedMask]:
"""Iterates through the generated masks."""
yield from self.masks

def __getitem__(self, idx: int) -> GeneratedMask:
"""Gets the generated masks by index."""
return self.masks[idx]


class ImageSegmentationModel(_model_garden_models._ModelGardenModel):
"""Segments an image."""

__module__ = "vertexai.preview.vision_models"

_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/image_segmentation_model_1.0.0.yaml"

def segment_image(
self,
base_image: Image,
prompt: Optional[str] = None,
scribble: Optional[Scribble] = None,
mode: Literal[
"foreground", "background", "semantic", "prompt", "interactive"
] = "foreground",
max_predictions: Optional[int] = None,
confidence_threshold: Optional[float] = 0.1,
mask_dilation: Optional[float] = None,
) -> ImageSegmentationResponse:
"""Segments an image.
Args:
base_image: The base image to segment.
prompt: The prompt to guide the segmentation. Valid for the prompt and
semantic modes.
scribble: The scribble in the form of an image mask to guide the
segmentation. Valid for the interactive mode. The scribble image
should be a black-and-white PNG file equal in size to the base
image. White pixels represent the scribbled brush stroke which
select objects in the base image to segment.
mode: The segmentation mode. Supported values are:
* foreground: segment the foreground object of an image
* background: segment the background of an image
* semantic: specify the objects to segment with a comma delimited
list of objects from the class set in the prompt.
* prompt: use an open-vocabulary text prompt to select objects to
segment.
* interactive: draw scribbles with a brush stroke to guide the
segmentation. The default is foreground.
max_predictions: The maximum number of predictions to make. Valid for
the prompt mode. Default is unlimited.
confidence_threshold: A threshold to filter predictions by confidence
score. The value must be in the range of 0.0 and 1.0. The default is
0.1.
mask_dilation: A value to dilate the masks by. The value must be in the
range of 0.0 (no dilation) and 1.0 (the whole image will be masked).
The default is 0.0.
Returns:
An `ImageSegmentationResponse` object with the generated masks,
entities, and labels (if any).
"""
if not base_image:
raise ValueError("Base image is required.")
instance = {}

if base_image._gcs_uri:
instance["image"] = {"gcsUri": base_image._gcs_uri}
else:
instance["image"] = {"bytesBase64Encoded": base_image._as_base64_string()}

if prompt:
instance["prompt"] = prompt

parameters = {}
if scribble and scribble.image:
scribble_image = scribble.image
if scribble_image._gcs_uri:
instance["scribble"] = {"image": {"gcsUri": scribble_image._gcs_uri}}
else:
instance["scribble"] = {
"image": {"bytesBase64Encoded": scribble_image._as_base64_string()}
}
parameters["mode"] = mode
        if max_predictions is not None:
            parameters["maxPredictions"] = max_predictions
        if confidence_threshold is not None:
            parameters["confidenceThreshold"] = confidence_threshold
        if mask_dilation is not None:
            parameters["maskDilation"] = mask_dilation

response = self._endpoint.predict(
instances=[instance],
parameters=parameters,
)

masks: List[GeneratedMask] = []
for prediction in response.predictions:
encoded_bytes = prediction.get("bytesBase64Encoded")
labels = []
if "labels" in prediction:
for label in prediction["labels"]:
labels.append(
EntityLabel(
label=label.get("label"),
score=label.get("score"),
)
)
            mask = GeneratedMask(
                image_bytes=base64.b64decode(encoded_bytes) if encoded_bytes else None,
                gcs_uri=prediction.get("gcsUri"),
                labels=labels,
            )
            masks.append(mask)

return ImageSegmentationResponse(
_prediction_response=response,
masks=masks,
)

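For reference, a sketch of the non-default modes described in the segment_image docstring above, reusing the model and base_image from the earlier snippet (the prompt text and scribble file name are illustrative):

# Semantic mode: name the classes to segment in a comma-delimited prompt.
response = model.segment_image(
    base_image=base_image,
    mode="semantic",
    prompt="cat, sofa",
)

# Interactive mode: a black-and-white PNG scribble, equal in size to the
# base image, marks the objects to segment with white brush strokes.
with open("scribble.png", "rb") as f:
    scribble = Scribble(image_bytes=f.read())
response = model.segment_image(
    base_image=base_image,
    mode="interactive",
    scribble=scribble,
)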