Skip to content

Commit

Permalink
update constant args to handle camel case
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh Harrington committed Feb 14, 2023
1 parent 5621119 commit 58365c4
Show file tree
Hide file tree
Showing 10 changed files with 13,020 additions and 4,796 deletions.
49 changes: 38 additions & 11 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from marshmallow import fields
from marshmallow.decorators import post_load, pre_dump
from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
from azure.ai.ml._schema.core.fields import StringTransformedEnum
from azure.ai.ml._schema.core.fields import NestedField
from azure.ai.ml.entities._workspace.networking import (
ManagedNetwork,
Expand All @@ -15,6 +16,8 @@
ServiceTagDestination,
PrivateEndpointDestination,
)
from azure.ai.ml.constants._workspace import IsolationMode, OutboundRuleCategory, OutboundRuleType
from azure.ai.ml._utils.utils import camel_to_snake, _snake_to_camel

from azure.ai.ml._utils._experimental import experimental

Expand All @@ -31,35 +34,51 @@ class DestinationSchema(metaclass=PatchedSchemaMeta):

@experimental
class OutboundRuleSchema(metaclass=PatchedSchemaMeta):
type = fields.Str(required=False)
type = StringTransformedEnum(
allowed_values=[OutboundRuleType.FQDN, OutboundRuleType.PRIVATE_ENDPOINT, OutboundRuleType.SERVICE_TAG],
casing_transform=camel_to_snake,
metadata={"description": "outbound rule type."},
)
destination = fields.Raw(required=True)
category = fields.Str(required=False)
category = StringTransformedEnum(
allowed_values=[
OutboundRuleCategory.REQUIRED,
OutboundRuleCategory.RECOMMENDED,
OutboundRuleCategory.USER_DEFINED,
],
casing_transform=camel_to_snake,
metadata={"description": "outbound rule category."},
)

@pre_dump
def predump(self, data, **kwargs):
if data and isinstance(data, FqdnDestination):
data.destination = self.fqdn_dest2dict(data.destination)
if data and isinstance(data, PrivateEndpointDestination):
data.destination = self.pe_dest2dict(
data.service_resource_id, data.subresource_target, data.spark_enabled
)
data.destination = self.pe_dest2dict(data.service_resource_id, data.subresource_target, data.spark_enabled)
if data and isinstance(data, ServiceTagDestination):
data.destination = self.service_tag_dest2dict(data.service_tag, data.protocol, data.port_ranges)
return data

@post_load
def createdestobject(self, data, **kwargs):
dest = data.get("destination", False)
category = data.get("category", OutboundRuleCategory.USER_DEFINED)
if dest:
if isinstance(dest, str):
return FqdnDestination(dest)
return FqdnDestination(dest, _snake_to_camel(category))
else:
if dest.get("subresource_target", False):
return PrivateEndpointDestination(
dest["service_resource_id"], dest["subresource_target"], dest["spark_enabled"]
dest["service_resource_id"],
dest["subresource_target"],
dest["spark_enabled"],
_snake_to_camel(category),
)
if dest.get("service_tag", False):
return ServiceTagDestination(dest["service_tag"], dest["protocol"], dest["port_ranges"])
return ServiceTagDestination(
dest["service_tag"], dest["protocol"], dest["port_ranges"], _snake_to_camel(category)
)
return OutboundRule(data)

def fqdn_dest2dict(self, fqdndest):
Expand All @@ -83,7 +102,15 @@ def service_tag_dest2dict(self, service_tag, protocol, port_ranges):

@experimental
class ManagedNetworkSchema(metaclass=PatchedSchemaMeta):
isolation_mode = fields.Str(required=True)
isolation_mode = StringTransformedEnum(
allowed_values=[
IsolationMode.DISABLED,
IsolationMode.ALLOW_INTERNET_OUTBOUND,
IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,
],
casing_transform=camel_to_snake,
metadata={"description": "isolation mode for the workspace managed network."},
)
outbound_rules = fields.Dict(
keys=fields.Str(required=True), values=NestedField(OutboundRuleSchema, allow_none=False), allow_none=True
)
Expand All @@ -92,6 +119,6 @@ class ManagedNetworkSchema(metaclass=PatchedSchemaMeta):
@post_load
def make(self, data, **kwargs):
if data.get("outbound_rules", False):
return ManagedNetwork(data["isolation_mode"], data["outbound_rules"])
return ManagedNetwork(_snake_to_camel(data["isolation_mode"]), data["outbound_rules"])
else:
return ManagedNetwork(data["isolation_mode"])
return ManagedNetwork(_snake_to_camel(data["isolation_mode"]))
24 changes: 15 additions & 9 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
ServiceTagOutboundRule as RestServiceTagOutboundRule,
ServiceTagOutboundRuleDestination as RestServiceTagOutboundRuleDestination,
)
from azure.ai.ml.constants._workspace import (
IsolationMode, OutboundRuleCategory, OutboundRuleType
)
from azure.ai.ml.constants._workspace import IsolationMode, OutboundRuleCategory, OutboundRuleType

from azure.ai.ml._utils._experimental import experimental

Expand Down Expand Up @@ -53,9 +51,9 @@ def _from_rest_object(cls, rest_obj: Any) -> "OutboundRule":

@experimental
class FqdnDestination(OutboundRule):
def __init__(self, destination: str) -> None:
def __init__(self, destination: str, category: str = OutboundRuleCategory.USER_DEFINED) -> None:
self.destination = destination
OutboundRule.__init__(self, type=OutboundRuleType.FQDN)
OutboundRule.__init__(self, type=OutboundRuleType.FQDN, category=category)

def _to_rest_object(self) -> RestFqdnOutboundRule:
return RestFqdnOutboundRule(type=self.type, category=self.category, destination=self.destination)
Expand All @@ -66,11 +64,17 @@ def _to_dict(self) -> Dict:

@experimental
class PrivateEndpointDestination(OutboundRule):
def __init__(self, service_resource_id: str, subresource_target: str, spark_enabled: bool = False) -> None:
def __init__(
self,
service_resource_id: str,
subresource_target: str,
spark_enabled: bool = False,
category: str = OutboundRuleCategory.USER_DEFINED,
) -> None:
self.service_resource_id = service_resource_id
self.subresource_target = subresource_target
self.spark_enabled = spark_enabled
OutboundRule.__init__(self, OutboundRuleType.PRIVATE_ENDPOINT)
OutboundRule.__init__(self, OutboundRuleType.PRIVATE_ENDPOINT, category=category)

def _to_rest_object(self) -> RestPrivateEndpointOutboundRule:
return RestPrivateEndpointOutboundRule(
Expand All @@ -97,11 +101,13 @@ def _to_dict(self) -> Dict:

@experimental
class ServiceTagDestination(OutboundRule):
def __init__(self, service_tag: str, protocol: str, port_ranges: str) -> None:
def __init__(
self, service_tag: str, protocol: str, port_ranges: str, category: str = OutboundRuleCategory.USER_DEFINED
) -> None:
self.service_tag = service_tag
self.protocol = protocol
self.port_ranges = port_ranges
OutboundRule.__init__(self, OutboundRuleType.SERVICE_TAG)
OutboundRule.__init__(self, OutboundRuleType.SERVICE_TAG, category=category)

def _to_rest_object(self) -> RestServiceTagOutboundRule:
return RestServiceTagOutboundRule(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from azure.core.credentials import TokenCredential
from azure.core.polling import LROPoller

from azure.ai.ml._utils.utils import _snake_to_camel

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger

Expand Down Expand Up @@ -87,7 +89,8 @@ def set(self, resource_group: str, ws_name: str, outbound_rule_name: str, **kwar

networkDto.outbound_rules = {}

type = kwargs.get("type", None) # pylint: disable=redefined-builtin
type = _snake_to_camel(kwargs.get("type", None)) # pylint: disable=redefined-builtin
type = OutboundRuleType.FQDN if type in ["fqdn", "Fqdn"] else type
destination = kwargs.get("destination", None)
service_tag = kwargs.get("service_tag", None)
protocol = kwargs.get("protocol", None)
Expand Down
Loading

0 comments on commit 58365c4

Please sign in to comment.