Skip to content

Commit

Permalink
feat: add managed autoscaling config to sagemaker endpoint module (#72)
Browse files Browse the repository at this point in the history
* feat: add managed autoscaling config to sagemaker endpoint module

Signed-off-by: Anton Kukushkin <[email protected]>

* changelog

Signed-off-by: Anton Kukushkin <[email protected]>

---------

Signed-off-by: Anton Kukushkin <[email protected]>
  • Loading branch information
kukushking authored Apr 30, 2024
1 parent d8aadc0 commit 0f5a208
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### **Added**

- added managed autoscaling config to `sagemaker-endpoint` module
- added SSO support in `sagemaker-studio` module
- added VPC/subnets/sg config for multi-account project template to `sagemaker-templates-service-catalog` module
- added `sagemaker-custom-kernel` module
Expand Down
4 changes: 4 additions & 0 deletions examples/manifests/sagemaker-endpoints-modules.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ parameters:
group: networking
name: networking
key: PrivateSubnetIds
- name: managed_instance_scaling
value: True
- name: scaling_max_instance_count
value: 10
3 changes: 3 additions & 0 deletions modules/sagemaker/sagemaker-endpoint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ where endpoints are provisioned as part of automated Continuous Integration and
- `initial-instance-count`: Initial instance count. `1` by default.
- `initial-variant-weight`: Initial variant weight. `1` by default.
- `instance-type`: instance type. `ml.m4.xlarge` by default.
- `managed-instance-scaling`: whether to enable managed instance autoscaling. `False` by default.
- `scaling-min-instance-count`: minimum autoscaling instance count. `1` by default. Only considered if `managed-instance-scaling` is `True`.
- `scaling-max-instance-count`: maximum autoscaling instance count. `10` by default. Only considered if `managed-instance-scaling` is `True`.

### Sample manifest declaration

Expand Down
8 changes: 8 additions & 0 deletions modules/sagemaker/sagemaker-endpoint/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _param(name: str) -> str:
DEFAULT_INITIAL_INSTANCE_COUNT = 1
DEFAULT_INITIAL_VARIANT_WEIGHT = 1
DEFAULT_INSTANCE_TYPE = "ml.m4.xlarge"
DEFAULT_SCALING_MIN_INSTANCE_COUNT = 1
DEFAULT_SCALING_MAX_INSTANCE_COUNT = 10

environment = aws_cdk.Environment(
account=os.environ["CDK_DEFAULT_ACCOUNT"],
Expand All @@ -49,6 +51,9 @@ def _param(name: str) -> str:
initial_instance_count = int(os.getenv(_param("INITIAL_INSTANCE_COUNT"), DEFAULT_INITIAL_INSTANCE_COUNT))
initial_variant_weight = int(os.getenv(_param("INITIAL_VARIANT_WEIGHT"), DEFAULT_INITIAL_VARIANT_WEIGHT))
instance_type = os.getenv(_param("INSTANCE_TYPE"), DEFAULT_INSTANCE_TYPE)
# Environment variables always arrive as strings, and bool("False") is True,
# so parse the text explicitly instead of relying on truthiness. An unset or
# empty variable leaves managed instance scaling disabled.
managed_instance_scaling = os.getenv(_param("MANAGED_INSTANCE_SCALING"), "false").strip().lower() in (
    "1",
    "true",
    "yes",
    "y",
    "on",
)
scaling_min_instance_count = int(os.getenv(_param("SCALING_MIN_INSTANCE_COUNT"), DEFAULT_SCALING_MIN_INSTANCE_COUNT))
scaling_max_instance_count = int(os.getenv(_param("SCALING_MAX_INSTANCE_COUNT"), DEFAULT_SCALING_MAX_INSTANCE_COUNT))

if not vpc_id:
raise ValueError("Missing input parameter vpc-id")
Expand Down Expand Up @@ -76,6 +81,9 @@ def _param(name: str) -> str:
"instance_type": instance_type,
"variant_name": variant_name,
},
managed_instance_scaling=managed_instance_scaling,
scaling_min_instance_count=scaling_min_instance_count,
scaling_max_instance_count=scaling_max_instance_count,
env=environment,
)

Expand Down
9 changes: 9 additions & 0 deletions modules/sagemaker/sagemaker-endpoint/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(
model_artifacts_bucket_arn: Optional[str],
ecr_repo_arn: Optional[str],
endpoint_config_prod_variant: Dict[str, Any],
managed_instance_scaling: bool,
scaling_min_instance_count: int,
scaling_max_instance_count: int,
**kwargs: Any,
) -> None:
super().__init__(scope, id, **kwargs)
Expand Down Expand Up @@ -171,6 +174,12 @@ def __init__(
sagemaker.CfnEndpointConfig.ProductionVariantProperty(
model_name=model_name,
**endpoint_config_prod_variant,
managed_instance_scaling=sagemaker.CfnEndpointConfig.ManagedInstanceScalingProperty(
max_instance_count=scaling_max_instance_count,
min_instance_count=scaling_min_instance_count,
)
if managed_instance_scaling
else None,
)
],
)
Expand Down
28 changes: 20 additions & 8 deletions modules/sagemaker/sagemaker-endpoint/tests/test_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def stack_model_package_input() -> cdk.Stack:
vpc_id = "vpc-12345"
model_package_arn = "example-arn"
model_artifacts_bucket_arn = "arn:aws:s3:::test-bucket"
managed_instance_scaling = True
scaling_min_instance_count = 1
scaling_max_instance_count = 2

return stack.DeployEndpointStack(
scope=app,
Expand All @@ -48,14 +51,17 @@ def stack_model_package_input() -> cdk.Stack:
subnet_ids=[],
model_artifacts_bucket_arn=model_artifacts_bucket_arn,
ecr_repo_arn=None,
env=cdk.Environment(
account=os.environ["CDK_DEFAULT_ACCOUNT"],
region=os.environ["CDK_DEFAULT_REGION"],
),
endpoint_config_prod_variant={
"initial_variant_weight": 1,
"variant_name": "AllTraffic",
},
managed_instance_scaling=managed_instance_scaling,
scaling_min_instance_count=scaling_min_instance_count,
scaling_max_instance_count=scaling_max_instance_count,
env=cdk.Environment(
account=os.environ["CDK_DEFAULT_ACCOUNT"],
region=os.environ["CDK_DEFAULT_REGION"],
),
)


Expand All @@ -75,6 +81,9 @@ def stack_latest_approved_model_package(mock_s3_client) -> cdk.Stack:
vpc_id = "vpc-12345"
model_package_group_name = "example-group"
model_artifacts_bucket_arn = "arn:aws:s3:::test-bucket"
managed_instance_scaling = True
scaling_min_instance_count = 1
scaling_max_instance_count = 2

sagemaker_client = botocore.session.get_session().create_client("sagemaker", region_name="us-east-1")
mock_s3_client.return_value = sagemaker_client
Expand Down Expand Up @@ -110,14 +119,17 @@ def stack_latest_approved_model_package(mock_s3_client) -> cdk.Stack:
subnet_ids=[],
model_artifacts_bucket_arn=model_artifacts_bucket_arn,
ecr_repo_arn=None,
env=cdk.Environment(
account=os.environ["CDK_DEFAULT_ACCOUNT"],
region=os.environ["CDK_DEFAULT_REGION"],
),
endpoint_config_prod_variant={
"initial_variant_weight": 1,
"variant_name": "AllTraffic",
},
managed_instance_scaling=managed_instance_scaling,
scaling_min_instance_count=scaling_min_instance_count,
scaling_max_instance_count=scaling_max_instance_count,
env=cdk.Environment(
account=os.environ["CDK_DEFAULT_ACCOUNT"],
region=os.environ["CDK_DEFAULT_REGION"],
),
)


Expand Down

0 comments on commit 0f5a208

Please sign in to comment.