Skip to content

Commit

Permalink
fix: use workspace-id to distinguish on-disk cache (Azure#30744)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh authored Jun 14, 2023
1 parent 2ac597f commit fd2bc08
Show file tree
Hide file tree
Showing 7 changed files with 419 additions and 69 deletions.
25 changes: 4 additions & 21 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,35 +95,18 @@ class CachedNodeResolver(object):
def __init__(
self,
resolver: Callable[[Union[Component, str]], str],
subscription_id: Optional[str],
resource_group_name: Optional[str],
workspace_name: Optional[str],
registry_name: Optional[str],
client_key: str,
):
self._resolver = resolver
self._cache: Dict[str, _CacheContent] = {}
self._nodes_to_resolve: List[BaseNode] = []

self._client_hash = self._get_client_hash(subscription_id, resource_group_name, workspace_name, registry_name)
hash_obj = hashlib.sha256()
hash_obj.update(client_key.encode("utf-8"))
self._client_hash = hash_obj.hexdigest()
# the same client share 1 lock
self._lock = _node_resolution_lock[self._client_hash]

@staticmethod
def _get_client_hash(
subscription_id: Optional[str],
resource_group_name: Optional[str],
workspace_name: Optional[str],
registry_name: Optional[str],
) -> str:
"""Get a hash for used client.
Works for both workspace client and registry client.
"""
object_hash = hashlib.sha256()
for s in [subscription_id, resource_group_name, workspace_name, registry_name]:
object_hash.update(str(s).encode("utf-8"))
return object_hash.hexdigest()

@staticmethod
def _get_component_registration_max_workers():
"""Get the max workers for component registration.
Expand Down
47 changes: 42 additions & 5 deletions sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from inspect import Parameter, signature
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from azure.core.exceptions import ResourceNotFoundError
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError

from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
Expand Down Expand Up @@ -92,6 +92,8 @@ def __init__(
self._managed_label_resolver = {"latest": self._get_latest_version}
self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config)

self._client_key: Optional[str] = None

@property
def _code_operations(self) -> CodeOperations:
return self._all_operations.get_operation(AzureMLResourceType.CODE, lambda x: isinstance(x, CodeOperations))
Expand Down Expand Up @@ -756,6 +758,44 @@ def _divide_nodes_to_resolve_into_layers(cls, component: PipelineComponent, extr

return layers

def _get_workspace_key(self) -> str:
try:
workspace_rest = self._workspace_operations._operation.get(
resource_group_name=self._resource_group_name, workspace_name=self._workspace_name
)
return workspace_rest.workspace_id
except HttpResponseError:
return "{}/{}/{}".format(self._subscription_id, self._resource_group_name, self._workspace_name)

def _get_registry_key(self) -> str:
"""Get key for used registry.
Note that, although registry id is in registry discovery response, it is not in RegistryDiscoveryDto; and we'll
lose the information after deserialization.
To avoid changing related rest client, we simply use registry related information from self to construct
registry key, which means that on-disk cache will be invalid if a registry is deleted and then created
again with the same name.
"""
return "{}/{}/{}".format(self._subscription_id, self._resource_group_name, self._registry_name)

def _get_client_key(self) -> str:
"""Get key for used client.
Key should be able to uniquely identify used registry or workspace.
"""
# check cache first
if self._client_key:
return self._client_key

# registry name has a higher priority comparing to workspace name according to current __init__ implementation
# of MLClient
if self._registry_name:
self._client_key = "registry/" + self._get_registry_key()
elif self._workspace_name:
self._client_key = "workspace/" + self._get_workspace_key()
else:
# This should never happen.
raise ValueError("Either workspace name or registry name must be provided to use component operations.")
return self._client_key

def _resolve_dependencies_for_pipeline_component_jobs(
self, component: Union[Component, str], resolver: Callable, *, resolve_inputs: bool = True
):
Expand Down Expand Up @@ -797,10 +837,7 @@ def _resolve_dependencies_for_pipeline_component_jobs(
# relatively simple and of small number of distinct instances
component_cache = CachedNodeResolver(
resolver=resolver,
subscription_id=self._subscription_id,
resource_group_name=self._resource_group_name,
workspace_name=self._workspace_name,
registry_name=self._registry_name,
client_key=self._get_client_key(),
)

for layer in reversed(layers):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from devtools_testutils import AzureRecordedTestCase

from azure.ai.ml import MLClient

from .._util import _COMPONENT_TIMEOUT_SECOND


@pytest.mark.e2etest
@pytest.mark.timeout(_COMPONENT_TIMEOUT_SECOND)
@pytest.mark.usefixtures("recorded_test")
@pytest.mark.pipeline_test
class TestComponentWithoutMock(AzureRecordedTestCase):
    """Do not use component related mock here."""

    def test_get_client_key(
        self, client: MLClient, registry_client: MLClient, pipelines_registry_client: MLClient
    ) -> None:
        """
        Test private interface to get client key.
        If you need to change the private interfaces and this test, please also update related code in
        mock_component_hash.
        """
        # Workspace-backed client: the client key must be the workspace key
        # with the "workspace/" prefix.
        workspace_key = client.components._get_workspace_key()
        assert workspace_key
        assert "workspace/" + workspace_key == client.components._get_client_key()

        # Registry-backed clients: keys carry the "registry/" prefix and two
        # distinct registries must yield distinct keys.
        registry_key1 = registry_client.components._get_registry_key()
        registry_key2 = pipelines_registry_client.components._get_registry_key()
        assert registry_key1
        assert registry_key2
        assert "registry/" + registry_key1 == registry_client.components._get_client_key()
        assert "registry/" + registry_key2 == pipelines_registry_client.components._get_client_key()
        assert registry_key1 != registry_key2
52 changes: 17 additions & 35 deletions sdk/ml/azure-ai-ml/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import hashlib
import json
import os
import random
Expand All @@ -9,11 +8,10 @@
import uuid
from collections import namedtuple
from datetime import datetime
from functools import partial
from importlib import reload
from os import getenv
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Tuple, Union
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -664,26 +662,6 @@ def generate_component_hash(*args, **kwargs):
return dict_hash


def get_client_hash_with_request_node_name(
    subscription_id: Optional[str],
    resource_group_name: Optional[str],
    workspace_name: Optional[str],
    registry_name: Optional[str],
    random_seed: str,
):
    """Generate a hash for the client."""
    # Hashing the concatenation of the stringified parts is equivalent to
    # feeding them to sha256 one update() at a time.
    parts = (
        subscription_id,
        resource_group_name,
        workspace_name,
        registry_name,
        random_seed,
    )
    digest = hashlib.sha256()
    digest.update("".join(map(str, parts)).encode("utf-8"))
    return digest.hexdigest()


def clear_on_disk_cache(cached_resolver):
"""Clear on disk cache for current client."""
cached_resolver._lock.acquire()
Expand Down Expand Up @@ -724,26 +702,30 @@ def mock_component_hash(mocker: MockFixture, request: FixtureRequest):
# and test2 in workspace B, the version in recordings can be different.
# So we use a random (probably unique) on-disk cache base directory for each test, and on-disk cache operations
# will be thread-safe when concurrently running different tests.
mocker.patch(
"azure.ai.ml._utils._cache_utils.CachedNodeResolver._get_client_hash",
side_effect=partial(get_client_hash_with_request_node_name, random_seed=uuid.uuid4().hex),
)
involved_client_keys = set()
if not is_live_and_not_recording():
# Get client id will involve a new request to server, which is specifically tested in some tests.
# We mock it in playback mode to avoid changing recordings for most tests.
mock_workspace_id, mock_registry_id = uuid.uuid4().hex, uuid.uuid4().hex
mocker.patch(
"azure.ai.ml.operations._component_operations.ComponentOperations._get_workspace_key",
return_value=mock_workspace_id,
)
mocker.patch(
"azure.ai.ml.operations._component_operations.ComponentOperations._get_registry_key",
return_value=mock_registry_id,
)
involved_client_keys = {mock_workspace_id, mock_registry_id}

# Collect involved resolvers before yield, as fixtures may be destroyed after yield.
from azure.ai.ml._utils._cache_utils import CachedNodeResolver

involved_resolvers = []
for client_fixture_name in ["client", "registry_client"]:
if client_fixture_name not in request.fixturenames:
continue
client: MLClient = request.getfixturevalue(client_fixture_name)
for client_key in involved_client_keys:
involved_resolvers.append(
CachedNodeResolver(
resolver=None,
subscription_id=client.subscription_id,
resource_group_name=client.resource_group_name,
workspace_name=client.workspace_name,
registry_name=client._operation_scope.registry_name,
client_key=client_key,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import mock
import pytest

from azure.ai.ml import MLClient, load_job
from azure.ai.ml._utils._cache_utils import CachedNodeResolver
from azure.ai.ml.entities import Component, PipelineJob
Expand All @@ -30,10 +31,7 @@ def _get_cache_path(component: Component, resolver: CachedNodeResolver) -> Path:
def create_resolver(client: MLClient) -> CachedNodeResolver:
return CachedNodeResolver(
resolver=TestCacheUtils._mock_resolver,
subscription_id=client.subscription_id,
resource_group_name=client.resource_group_name,
workspace_name=client.workspace_name,
registry_name=client._operation_scope.registry_name,
client_key=client.components._get_client_key(),
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import pytest
import vcr
import yaml
from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes
from azure.ai.ml.exceptions import ValidationException
from azure.core.credentials import AccessToken
from azure.core.exceptions import HttpResponseError
from azure.identity import DefaultAzureCredential
from msrest import Deserializer
from pytest_mock import MockFixture

from azure.ai.ml import MLClient, load_job
from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes
from azure.ai.ml._restclient.v2023_04_01_preview import models
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope
from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR, AzureMLResourceType
Expand All @@ -22,13 +23,12 @@
from azure.ai.ml.entities._job.automl.training_settings import TrainingSettings
from azure.ai.ml.entities._job.job import Job
from azure.ai.ml.entities._job.sweep.sweep_job import SweepJob
from azure.ai.ml.exceptions import ValidationException
from azure.ai.ml.operations import DatastoreOperations, EnvironmentOperations, JobOperations, WorkspaceOperations
from azure.ai.ml.operations._code_operations import CodeOperations
from azure.ai.ml.operations._job_ops_helper import get_git_properties
from azure.ai.ml.operations._run_history_constants import RunHistoryConstants
from azure.ai.ml.operations._run_operations import RunOperations
from azure.core.exceptions import HttpResponseError
from azure.identity import DefaultAzureCredential

from .test_vcr_utils import before_record_cb, vcr_header_filters

Expand Down Expand Up @@ -147,6 +147,8 @@ def test_get(self, mock_method, mock_job_operation: JobOperations) -> None:
mock_job_operation.get("randon_name")
mock_job_operation._operation_2023_02_preview.get.assert_called_once()

# use mock_component_hash to avoid passing a Mock object as client key
@pytest.mark.usefixtures("mock_component_hash")
@patch.object(JobOperations, "_get_job")
def test_get_job(self, mock_method, mock_job_operation: JobOperations) -> None:
from azure.ai.ml import Input, dsl, load_component
Expand Down
Loading

0 comments on commit fd2bc08

Please sign in to comment.