Skip to content

Commit

Permalink
Add gax support for entity annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
daspecster committed Jan 5, 2017
1 parent 60f1ada commit c230cf6
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 5 deletions.
15 changes: 15 additions & 0 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def _assert_face(self, face):

def test_detect_faces_content(self):
client = Config.CLIENT
client._use_gax = False
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
faces = image.detect_faces()
Expand All @@ -208,6 +209,7 @@ def test_detect_faces_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = False
image = client.image(source_uri=source_uri)
faces = image.detect_faces()
self.assertEqual(len(faces), 5)
Expand All @@ -216,6 +218,7 @@ def test_detect_faces_gcs(self):

def test_detect_faces_filename(self):
client = Config.CLIENT
client._use_gax = False
image = client.image(filename=FACE_FILE)
faces = image.detect_faces()
self.assertEqual(len(faces), 5)
Expand Down Expand Up @@ -310,6 +313,7 @@ def _assert_landmark(self, landmark):

def test_detect_landmark_content(self):
client = Config.CLIENT
client._use_gax = True
with open(LANDMARK_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
landmarks = image.detect_landmarks()
Expand All @@ -328,6 +332,7 @@ def test_detect_landmark_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = True
image = client.image(source_uri=source_uri)
landmarks = image.detect_landmarks()
self.assertEqual(len(landmarks), 1)
Expand All @@ -336,6 +341,7 @@ def test_detect_landmark_gcs(self):

def test_detect_landmark_filename(self):
client = Config.CLIENT
client._use_gax = True
image = client.image(filename=LANDMARK_FILE)
landmarks = image.detect_landmarks()
self.assertEqual(len(landmarks), 1)
Expand All @@ -362,6 +368,7 @@ def _assert_safe_search(self, safe_search):

def test_detect_safe_search_content(self):
client = Config.CLIENT
client._use_gax = False
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
safe_searches = image.detect_safe_search()
Expand All @@ -380,6 +387,7 @@ def test_detect_safe_search_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = False
image = client.image(source_uri=source_uri)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
Expand All @@ -388,6 +396,7 @@ def test_detect_safe_search_gcs(self):

def test_detect_safe_search_filename(self):
client = Config.CLIENT
client._use_gax = False
image = client.image(filename=FACE_FILE)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
Expand Down Expand Up @@ -423,6 +432,7 @@ def _assert_text(self, text):

def test_detect_text_content(self):
client = Config.CLIENT
client._use_gax = True
with open(TEXT_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
texts = image.detect_text()
Expand All @@ -441,6 +451,7 @@ def test_detect_text_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = True
image = client.image(source_uri=source_uri)
texts = image.detect_text()
self.assertEqual(len(texts), 9)
Expand All @@ -449,6 +460,7 @@ def test_detect_text_gcs(self):

def test_detect_text_filename(self):
client = Config.CLIENT
client._use_gax = True
image = client.image(filename=TEXT_FILE)
texts = image.detect_text()
self.assertEqual(len(texts), 9)
Expand Down Expand Up @@ -485,6 +497,7 @@ def _assert_properties(self, image_property):

def test_detect_properties_content(self):
client = Config.CLIENT
client._use_gax = False
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
properties = image.detect_properties()
Expand All @@ -503,6 +516,7 @@ def test_detect_properties_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = False
image = client.image(source_uri=source_uri)
properties = image.detect_properties()
self.assertEqual(len(properties), 1)
Expand All @@ -511,6 +525,7 @@ def test_detect_properties_gcs(self):

def test_detect_properties_filename(self):
client = Config.CLIENT
client._use_gax = False
image = client.image(filename=FACE_FILE)
properties = image.detect_properties()
self.assertEqual(len(properties), 1)
Expand Down
23 changes: 23 additions & 0 deletions vision/google/cloud/vision/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from google.cloud._helpers import _to_bytes

from google.cloud.vision.annotations import Annotations


class _GAPICVisionAPI(object):
"""Vision API for interacting with the gRPC version of Vision.
Expand All @@ -30,6 +32,27 @@ def __init__(self, client=None):
self._client = client
self._api = image_annotator_client.ImageAnnotatorClient()

def annotate(self, image, features):
"""Annotate images through GAX.
:type image: :class:`~google.cloud.vision.image.Image`
:param image: Instance of ``Image``.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
:returns: Instance of ``Annotations`` with results.
"""
gapic_features = [_to_gapic_feature(feature) for feature in features]
gapic_image = _to_gapic_image(image)
request = image_annotator_pb2.AnnotateImageRequest(
image=gapic_image, features=gapic_features)
requests = [request]
api = self._api
responses = api.batch_annotate_images(requests)
return Annotations.from_pb(responses.responses[0])


def _to_gapic_feature(feature):
"""Helper function to convert a ``Feature`` to a gRPC ``Feature``.
Expand Down
3 changes: 2 additions & 1 deletion vision/google/cloud/vision/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""HTTP Client for interacting with the Google Cloud Vision API."""

from google.cloud.vision.annotations import Annotations
from google.cloud.vision.feature import Feature


Expand Down Expand Up @@ -49,7 +50,7 @@ def annotate(self, image, features):
api_response = self._connection.api_request(
method='POST', path='/images:annotate', data=data)
responses = api_response.get('responses')
return responses[0]
return Annotations.from_api_repr(responses[0])


def _make_request(image, features):
Expand Down
49 changes: 49 additions & 0 deletions vision/google/cloud/vision/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,55 @@ def from_api_repr(cls, response):
_entity_from_response_type(feature_type, annotation))
return cls(**annotations)

@classmethod
def from_pb(cls, response):
"""Factory: construct an instance of ``Annotations`` from gRPC response.
:type response: :class:`~google.cloud.grpc.vision.v1.\
image_annotator_pb2.AnnotateImageResponse`
:param response: ``AnnotateImageResponse`` from gRPC call.
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
:returns: ``Annotations`` instance populated from gRPC response.
"""
annotations = _process_image_annotations(response)
return cls(**annotations)


def _process_image_annotations(image):
"""Helper for processing annotation types from gRPC responses.
:type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\
AnnotateImageResponse`
:param image: ``AnnotateImageResponse`` from gRPC response.
:rtype: dict
:returns: Dictionary populated with entities from response.
"""
annotations = {}
annotations['labels'] = _make_entity_from_pb(image.label_annotations)
annotations['landmarks'] = _make_entity_from_pb(image.landmark_annotations)
annotations['logos'] = _make_entity_from_pb(image.logo_annotations)
annotations['texts'] = _make_entity_from_pb(image.text_annotations)
return annotations


def _make_entity_from_pb(annotations):
"""Create an entity from a gRPC response.
:type annotations:
:class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.EntityAnnotation`
:param annotations: gRPC instance of ``EntityAnnotation``.
:rtype: list
:returns: List of ``EntityAnnotation``.
"""

entities = []
for annotation in annotations:
entities.append(EntityAnnotation.from_pb(annotation))
return entities


def _entity_from_response_type(feature_type, results):
"""Convert a JSON result to an entity type based on the feature.
Expand Down
20 changes: 20 additions & 0 deletions vision/google/cloud/vision/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,26 @@ def from_api_repr(cls, response):

return cls(bounds, description, locale, locations, mid, score)

@classmethod
def from_pb(cls, response):
"""Factory: construct entity from Vision gRPC response.
:type response: :class:`~google.cloud.grpc.vision.v1.\
image_annotator_pb2.AnnotateImageResponse`
:param response: gRPC response from Vision API with entity data.
:rtype: :class:`~google.cloud.vision.entity.EntityAnnotation`
:returns: Instance of ``EntityAnnotation``.
"""
bounds = Bounds.from_pb(response.bounding_poly)
description = response.description
locale = response.locale
locations = [LocationInformation.from_pb(location)
for location in response.locations]
mid = response.mid
score = response.score
return cls(bounds, description, locale, locations, mid, score)

@property
def bounds(self):
"""Bounding polygon of detected image feature.
Expand Down
27 changes: 27 additions & 0 deletions vision/google/cloud/vision/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ def from_api_repr(cls, response_vertices):
vertex in response_vertices.get('vertices', [])]
return cls(vertices)

@classmethod
def from_pb(cls, response_vertices):
"""Factory: construct BoundsBase instance from Vision gRPC response.
:type response_vertices: :class:`~google.cloud.grpc.vision.v1.\
geometry_pb2.BoundingPoly`
:param response_vertices: List of vertices.
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
:returns: Instance of ``BoundsBase`` with populated verticies.
"""
return cls([Vertex(vertex.x, vertex.y)
for vertex in response_vertices.vertices])

@property
def vertices(self):
"""List of vertices.
Expand Down Expand Up @@ -87,6 +101,19 @@ def from_api_repr(cls, response):
longitude = response['latLng']['longitude']
return cls(latitude, longitude)

@classmethod
def from_pb(cls, response):
"""Factory: construct location information from Vision gRPC response.
:type response: :class:`~google.cloud.vision.v1.LocationInfo`
:param response: gRPC response of ``LocationInfo``.
:rtype: :class:`~google.cloud.vision.geometry.LocationInformation`
:returns: ``LocationInformation`` with populated latitude and
longitude.
"""
return cls(response.lat_lng.latitude, response.lat_lng.longitude)

@property
def latitude(self):
"""Latitude coordinate.
Expand Down
4 changes: 1 addition & 3 deletions vision/google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from google.cloud._helpers import _to_bytes
from google.cloud._helpers import _bytes_to_unicode
from google.cloud.vision.annotations import Annotations
from google.cloud.vision.feature import Feature
from google.cloud.vision.feature import FeatureTypes

Expand Down Expand Up @@ -109,8 +108,7 @@ def _detect_annotation(self, features):
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
"""
results = self.client._vision_api.annotate(self, features)
return Annotations.from_api_repr(results)
return self.client._vision_api.annotate(self, features)

def detect(self, features):
"""Detect multiple feature types.
Expand Down
20 changes: 20 additions & 0 deletions vision/unit_tests/test__gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@ def test_ctor(self):
api = self._make_one(client)
self.assertIs(api._client, client)

def test_annotation(self):
from google.cloud.vision.feature import Feature
from google.cloud.vision.feature import FeatureTypes
from google.cloud.vision.image import Image

client = mock.Mock()
feature = Feature(FeatureTypes.LABEL_DETECTION, 5)
image_content = b'abc 1 2 3'
image = Image(client, content=image_content)
api = self._make_one(client)

api._api = mock.Mock()
mock_response = mock.Mock(responses=['mock response data'])
api._api.batch_annotate_images.return_value = mock_response

with mock.patch('google.cloud.vision._gax.Annotations') as mock_anno:
api.annotate(image, [feature])
mock_anno.from_pb.assert_called_with('mock response data')
api._api.batch_annotate_images.assert_called()


class TestToGAPICFeature(unittest.TestCase):
def _call_fut(self, feature):
Expand Down
Loading

0 comments on commit c230cf6

Please sign in to comment.