Skip to content

Commit

Permalink
Ran linter
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed May 4, 2022
1 parent 528a915 commit c402b50
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 27 deletions.
16 changes: 12 additions & 4 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,15 @@ def update_metadata(
)

_LOGGER.log_action_start_against_resource(
"Updating", "index", self,
"Updating",
"index",
self,
)

update_lro = self.api_client.update_index(
index=gapic_index, update_mask=update_mask, metadata=request_metadata,
index=gapic_index,
update_mask=update_mask,
metadata=request_metadata,
)

_LOGGER.log_action_started_against_resource_with_lro(
Expand Down Expand Up @@ -326,11 +330,15 @@ def update_embeddings(
)

_LOGGER.log_action_start_against_resource(
"Updating", "index", self,
"Updating",
"index",
self,
)

update_lro = self.api_client.update_index(
index=gapic_index, update_mask=update_mask, metadata=request_metadata,
index=gapic_index,
update_mask=update_mask,
metadata=request_metadata,
)

_LOGGER.log_action_started_against_resource_with_lro(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ def create(
"""
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name, description=description, network=network,
display_name=display_name,
description=description,
network=network,
)

if labels:
Expand Down Expand Up @@ -411,16 +413,20 @@ def _build_deployed_index(
machine_type=machine_type
)

deployed_index.dedicated_resources = gca_machine_resources_compat.DedicatedResources(
machine_spec=machine_spec,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
deployed_index.dedicated_resources = (
gca_machine_resources_compat.DedicatedResources(
machine_spec=machine_spec,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
)

else:
deployed_index.automatic_resources = gca_machine_resources_compat.AutomaticResources(
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
deployed_index.automatic_resources = (
gca_machine_resources_compat.AutomaticResources(
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
)
return deployed_index

Expand Down Expand Up @@ -538,7 +544,9 @@ def deploy_index(
self.wait()

_LOGGER.log_action_start_against_resource(
"Deploying index", "index_endpoint", self,
"Deploying index",
"index_endpoint",
self,
)

deployed_index = self._build_deployed_index(
Expand Down Expand Up @@ -596,7 +604,9 @@ def undeploy_index(
self.wait()

_LOGGER.log_action_start_against_resource(
"Undeploying index", "index_endpoint", self,
"Undeploying index",
"index_endpoint",
self,
)

undeploy_lro = self.api_client.undeploy_index(
Expand Down Expand Up @@ -659,7 +669,9 @@ def mutate_deployed_index(
self.wait()

_LOGGER.log_action_start_against_resource(
"Mutating index", "index_endpoint", self,
"Mutating index",
"index_endpoint",
self,
)

deployed_index = self._build_deployed_index(
Expand Down
3 changes: 2 additions & 1 deletion tests/system/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ class TestMatchingEngine(e2e_base.TestEndToEnd):

def test_create_get_list_matching_engine_index(self, shared_state):
aiplatform.init(
project=e2e_base._PROJECT, location=e2e_base._LOCATION,
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
)

project_number = get_project_number(e2e_base._PROJECT)
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def update_index_embeddings_mock():
index_service_client.IndexServiceClient, "update_index"
) as update_index_mock:
index_lro_mock = mock.Mock(operation.Operation)
index_lro_mock.result.return_value = gca_index.Index(name=_TEST_INDEX_NAME,)
index_lro_mock.result.return_value = gca_index.Index(
name=_TEST_INDEX_NAME,
)
update_index_mock.return_value = index_lro_mock
yield update_index_mock

Expand Down Expand Up @@ -307,7 +309,9 @@ def test_create_tree_ah_index(self, create_index_mock, sync):
)

create_index_mock.assert_called_once_with(
parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA,
parent=_TEST_PARENT,
index=expected,
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
Expand Down Expand Up @@ -346,6 +350,7 @@ def test_create_brute_force_index(self, create_index_mock, sync):
)

create_index_mock.assert_called_once_with(
parent=_TEST_PARENT, index=expected, metadata=_TEST_REQUEST_METADATA,
parent=_TEST_PARENT,
index=expected,
metadata=_TEST_REQUEST_METADATA,
)

19 changes: 12 additions & 7 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def get_index_mock():

index.deployed_indexes = [
gca_matching_engine_deployed_index_ref.DeployedIndexRef(
index_endpoint=index.name, deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
index_endpoint=index.name,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
)
]

Expand Down Expand Up @@ -298,7 +299,8 @@ def get_index_endpoint_mock():
@pytest.fixture
def deploy_index_mock():
with patch.object(
index_endpoint_service_client.IndexEndpointServiceClient, "deploy_index",
index_endpoint_service_client.IndexEndpointServiceClient,
"deploy_index",
) as deploy_index_mock:
deploy_index_lro_mock = mock.Mock(operation.Operation)
deploy_index_mock.return_value = deploy_index_lro_mock
Expand All @@ -308,7 +310,8 @@ def deploy_index_mock():
@pytest.fixture
def undeploy_index_mock():
with patch.object(
index_endpoint_service_client.IndexEndpointServiceClient, "undeploy_index",
index_endpoint_service_client.IndexEndpointServiceClient,
"undeploy_index",
) as undeploy_index_mock:
undeploy_index_lro_mock = mock.Mock(operation.Operation)
undeploy_index_mock.return_value = undeploy_index_lro_mock
Expand Down Expand Up @@ -369,10 +372,12 @@ def create_index_endpoint_mock():
"create_index_endpoint",
) as create_index_endpoint_mock:
create_index_endpoint_lro_mock = mock.Mock(operation.Operation)
create_index_endpoint_lro_mock.result.return_value = gca_index_endpoint.IndexEndpoint(
name=_TEST_INDEX_ENDPOINT_NAME,
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
create_index_endpoint_lro_mock.result.return_value = (
gca_index_endpoint.IndexEndpoint(
name=_TEST_INDEX_ENDPOINT_NAME,
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
)
)
create_index_endpoint_mock.return_value = create_index_endpoint_lro_mock
yield create_index_endpoint_mock
Expand Down

0 comments on commit c402b50

Please sign in to comment.