Skip to content

Commit

Permalink
fix: avoid OSError in LocalPathField deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Feb 14, 2023
1 parent d12ca3a commit 6108ce7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
24 changes: 12 additions & 12 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,20 @@ def _resolve_path(self, value) -> Path:
"""Resolve path to absolute path based on base_path in context.
Will resolve the path if it's already an absolute path.
"""
result = Path(value)
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
if not result.is_absolute():
result = base_path / result
try:
return result.resolve()
result = Path(value)
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
if not result.is_absolute():
result = base_path / result

result = result.resolve()
if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()):
return result
except OSError:
raise self.make_error("invalid_path")
raise self.make_error(
"path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type
)

@property
def allowed_path_type(self) -> str:
Expand All @@ -164,18 +170,12 @@ def _validate(self, value):

if value is None:
return
path = self._resolve_path(value)
if (self._allow_dir and path.is_dir()) or (self._allow_file and path.is_file()):
return
raise self.make_error(
"path_not_exist", path=path.as_posix(), allow_type=self.allowed_path_type
)
self._resolve_path(value)

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
# do not block serializing None even if required or not allow_none.
if value is None:
return None
self._validate(value)
# always dump path as absolute path in string as base_path will be dropped after serialization
return super(LocalPathField, self)._serialize(
self._resolve_path(value).as_posix(), attr, obj, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ def test_spark_component_version_as_a_function_with_inputs(self):
rest_yaml_component = yaml_component._to_rest_object()

assert rest_yaml_component == expected_rest_component

def test_from_rest_object(self):
invalid_code = "azureml:/subscriptions/xxx/resourceGroups/xxx/providers/Microsoft.MachineLearningServices/" \
"workspaces/zzz/codes/90b33c11-365d-4ee4-aaa1-224a042deb41/versions/1"
yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
yaml_component = load_component(yaml_path)

from azure.ai.ml.entities import Component
rest_object = yaml_component._to_rest_object()
rest_object.properties.component_spec["code"] = invalid_code
component = Component._from_rest_object(rest_object)
assert component.code == invalid_code[8:]

0 comments on commit 6108ce7

Please sign in to comment.