Skip to content

Commit

Permalink
Model package validate config (Azure#33399)
Browse files Browse the repository at this point in the history
* Validate mode for model package

* e2e

* e2e tests

* pylint

* fix
  • Loading branch information
nemanjarajic authored Dec 6, 2023
1 parent a3aa35e commit 091f943
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/ml/azure-ai-ml",
"Tag": "python/ml/azure-ai-ml_d2948b931b"
"Tag": "python/ml/azure-ai-ml_6563ba1da3"
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from azure.ai.ml._restclient.v2023_08_01_preview.models import ModelConfiguration as RestModelConfiguration
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._utils.utils import snake_to_camel
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
from azure.ai.ml._exception_helper import log_and_raise_error


@experimental
Expand Down Expand Up @@ -36,4 +37,17 @@ def _from_rest_object(cls, rest_obj: RestModelConfiguration) -> "ModelConfigurat
return ModelConfiguration(mode=rest_obj.mode, mount_path=rest_obj.mount_path)

def _to_rest_object(self) -> RestModelConfiguration:
return RestModelConfiguration(mode=snake_to_camel(self.mode), mount_path=self.mount_path)
self.validate()
return RestModelConfiguration(mode=self.mode, mount_path=self.mount_path)

def validate(self):
if self.mode.lower() not in ["copy", "download"]:
msg = "Mode must be either 'Copy' or 'Download'"
err = ValidationException(
message=msg,
target=ErrorTarget.MODEL,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.INVALID_VALUE,
)
log_and_raise_error(err)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_model_package_workspace(self, client: MLClient):
package_config = ModelPackage(
target_environment="my-package-name",
inferencing_server=AzureMLOnlineInferencingServer(),
model_configuration=ModelConfiguration(mode="copy"),
model_configuration=ModelConfiguration(mode="Copy"),
)

client.models.package("test-model2", "1", package_config)
client.models.package("test-model-1", "1", package_config)

0 comments on commit 091f943

Please sign in to comment.