Skip to content

Commit

Permalink
fixing unit test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij-microsoft committed Sep 30, 2024
1 parent 3dbb2d2 commit 077cf3b
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def __init__(
self.managed_network = managed_network

@classmethod
def _from_rest_object(cls, rest_obj: RestWorkspace) -> Optional["FeatureStore"]:
def _from_rest_object(
cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None
) -> Optional["FeatureStore"]:
if not rest_obj:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _get_schema_class(cls):
return HubSchema

@classmethod
def _from_rest_object(cls, rest_obj: RestWorkspace, v2_service_context: Optional[object]) -> Optional["Hub"]:
def _from_rest_object(cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None) -> Optional["Hub"]:
if not rest_obj:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,6 @@ def mlflow_tracking_uri(self) -> Optional[str]:
:return: Returns mlflow tracking uri of the workspace.
:rtype: str
"""
# TODO: To check with Amit the use of this function

return self._mlflow_tracking_uri

def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
Expand Down Expand Up @@ -317,7 +315,9 @@ def _load(
return result

@classmethod
def _from_rest_object(cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None) -> Optional["Workspace"]:
def _from_rest_object(
cls, rest_obj: RestWorkspace, v2_service_context: Optional[object] = None
) -> Optional["Workspace"]:

if not rest_obj:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,18 @@ def get(self, workspace_name: Optional[str] = None, **kwargs: Any) -> Optional[W
resource_group = kwargs.get("resource_group") or self._resource_group_name
obj = self._operation.get(resource_group, workspace_name)
v2_service_context = {}

v2_service_context["subscription_id"] = self._subscription_id
v2_service_context["workspace_name"] = workspace_name
v2_service_context["resource_group_name"] = resource_group
v2_service_context["auth"] = self._credentials
v2_service_context["auth"] = self._credentials # type: ignore

from urllib.parse import urlparse

parsed_url = urlparse(obj.ml_flow_tracking_uri)
host_url = "https://{}".format(parsed_url.netloc)
v2_service_context['host_url'] = host_url
v2_service_context["host_url"] = host_url

# host_url=service_context._get_mlflow_url(),
# cloud=_get_cloud_or_default(
# service_context.get_auth()._cloud_type.name
Expand Down Expand Up @@ -436,7 +437,7 @@ def callback(_: Any, deserialized: Any, args: Any) -> Workspace:
return (
deserialize_callback(deserialized)
if deserialize_callback
else Workspace._from_rest_object(deserialized, None)
else Workspace._from_rest_object(deserialized)
)

real_callback = callback
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional
from unittest.mock import ANY, DEFAULT, MagicMock, Mock
from unittest.mock import ANY, DEFAULT, MagicMock, Mock, patch
from uuid import UUID, uuid4

import pytest
Expand All @@ -20,13 +20,23 @@
)
from azure.ai.ml.operations import WorkspaceOperations
from azure.core.polling import LROPoller
import urllib.parse


@pytest.fixture
def mock_credential() -> Mock:
yield Mock()


def mock_urlparse(url: str) -> urllib.parse.ParseResult:
return urllib.parse.ParseResult(
scheme="http", netloc="example.com", path="/index.html", params="", query="a=1&b=2", fragment=""
)


urllib.parse.urlparse = mock_urlparse


@pytest.fixture
def mock_workspace_operation(
mock_workspace_scope: OperationScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@
)
from azure.ai.ml.operations._workspace_operations_base import WorkspaceOperationsBase
from azure.core.polling import LROPoller
import urllib.parse


@pytest.fixture
def mock_credential() -> Mock:
yield Mock()


def mock_urlparse(url: str) -> urllib.parse.ParseResult:
return urllib.parse.ParseResult(
scheme="http", netloc="example.com", path="/index.html", params="", query="a=1&b=2", fragment=""
)


urllib.parse.urlparse = mock_urlparse


@pytest.fixture
def mock_workspace_operation_base(
mock_workspace_scope: OperationScope,
Expand Down

0 comments on commit 077cf3b

Please sign in to comment.