From b2fd14d0a7129576fc849d3391163f0e0836251c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Jun 2024 14:11:29 -0400 Subject: [PATCH] Update the SageMaker Model Monitoring module with model bias monitoring. (#126) Co-authored-by: Patrick Cloke --- CHANGELOG.md | 2 +- README.md | 2 +- .../sagemaker-model-monitoring/README.md | 21 +++- .../model_bias_construct.py | 117 ++++++++++++++++++ .../sagemaker_model_monitoring/settings.py | 16 +++ .../sagemaker_model_monitoring/stack.py | 49 +++++++- .../tests/test_app.py | 5 +- .../tests/test_stack.py | 15 ++- 8 files changed, 214 insertions(+), 13 deletions(-) create mode 100644 modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/model_bias_construct.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 49d4c743..b70e2962 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### **Added** - added `ray-on-eks`, and `manifests/ray-on-eks` manifests -- Added a `sagemaker-model-monitoring-module` module with an example of data quality and model quality monitoring of a SageMaker Endpoint. +- Added a `sagemaker-model-monitoring-module` module with an example of data quality, model quality, and model bias monitoring of a SageMaker Endpoint. - Added an option to enable data capture in the `sagemaker-endpoint-module`. ### **Changed** diff --git a/README.md b/README.md index 4ca2eaf0..c8a8b810 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ See deployment steps in the [Deployment Guide](DEPLOYMENT.md). | [SageMaker Custom Kernel Module](modules/sagemaker/sagemaker-custom-kernel/README.md) | Builds custom kernel for SageMaker Studio from a Dockerfile | | [SageMaker Model Package Group Module](modules/sagemaker/sagemaker-model-package-group/README.md) | Creates a SageMaker Model Package Group to register and version SageMaker Machine Learning (ML) models and setups an Amazon EventBridge Rule to send model package group state change events to an Amazon EventBridge Bus | | [SageMaker Model Package Promote Pipeline Module](modules/sagemaker/sagemaker-model-package-promote-pipeline/README.md) | Deploy a Pipeline to promote SageMaker Model Packages in a multi-account setup. The pipeline can be triggered through an EventBridge rule in reaction of a SageMaker Model Package Group state event change (Approved/Rejected). Once the pipeline is triggered, it will promote the latest approved model package, if one is found. | -| [SageMaker Model Monitoring Module](modules/sagemaker/sagemaker-model-monitoring-module/README.md) | Deploy data quality & model quality monitoring jobs which run against a SageMaker Endpoint. | +| [SageMaker Model Monitoring Module](modules/sagemaker/sagemaker-model-monitoring-module/README.md) | Deploy data quality, model quality, and model bias monitoring jobs which run against a SageMaker Endpoint. | ### Mlflow Modules diff --git a/modules/sagemaker/sagemaker-model-monitoring/README.md b/modules/sagemaker/sagemaker-model-monitoring/README.md index 3a407ab9..c82532e6 100644 --- a/modules/sagemaker/sagemaker-model-monitoring/README.md +++ b/modules/sagemaker/sagemaker-model-monitoring/README.md @@ -2,12 +2,13 @@ ## Description -This module creates a SageMaker Model Monitoring jobs for (optionally) data quality and model quality. +This module creates SageMaker Model Monitoring jobs for (optionally) data quality, model bias, and model quality. It requires a deployed model endpoint and the proper check steps for each monitoring job: * Data Quality: [QualityCheck step](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-steps.html#step-type-quality-check) * Model Quality: [QualityCheck step](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-steps.html#step-type-quality-check) +* Model Bias: [ClarifyCheck step](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-steps.html#step-type-clarify-check) ### Architecture @@ -30,6 +31,11 @@ One or more of: - `enable-data-quality-monitor`: True to enable the data quality monitoring job. - `enable-model-quality-monitor`: True to enable the model quality monitoring job. +- `enable-model-bias-monitor`: True to enable the model bias monitoring job. + +#### Required for some jobs + +- `ground-truth-prefix`: The S3 prefix in `model-artifacts-bucket-arn` which contains the [ground truth](https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-model-quality-merge.html) for captured data. Required if `enable-model-quality-monitor` or `enable-model-bias-monitor` is true. #### Optional @@ -68,15 +74,24 @@ N/A ###### Required -- `ground-truth-prefix`: The S3 prefix in `model-artifacts-bucket-arn` which contains the [ground truth](https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-model-quality-merge.html) for captured data. - `model-quality-problem-type`: The machine learning problem type of the model that the monitoring job monitors. -- `model-quality-inference-attribute`: The attribute of the input data that represents the ground truth label. ###### Optional +- `model-quality-inference-attribute`: The attribute of the input data that represents the ground truth label. - `model-quality-probability-attribute`: In a classification problem, the attribute that represents the class probability. - `model-quality-probability-threshold-attribute`: The threshold for the class probability to be evaluated as a positive result. +##### Model Bias Monitoring Job Parameters + +###### Optional + +- `model-bias-checkstep-analysis-config-prefix`: The S3 prefix in `model-artifacts-bucket-arn` which contains the output from the [Clarify Check Step](https://docs.aws.amazon.com/sagemaker/latest/dg/build-and-manage-steps.html#step-type-clarify-check) used for model bias. +- `model-bias-features-attribute`: The attributes of the input data that are the input features. +- `model-bias-inference-attribute`: The attribute of the input data that represents the ground truth label. +- `model-bias-probability-attribute`: In a classification problem, the attribute that represents the class probability. +- `model-bias-probability-threshold-attribute`: The threshold for the class probability to be evaluated as a positive result. + ### Sample manifest declaration ```yaml diff --git a/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/model_bias_construct.py b/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/model_bias_construct.py new file mode 100644 index 00000000..172f1087 --- /dev/null +++ b/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/model_bias_construct.py @@ -0,0 +1,117 @@ +from typing import Any, List, Optional + +from aws_cdk import aws_sagemaker as sagemaker +from constructs import Construct + + +class ModelBiasConstruct(Construct): + """ + CDK construct to define a SageMaker model bias job definition. + + It defines both the model bias job, corresponding schedule, and configuration. + """ + + def __init__( + self, + scope: Construct, + construct_id: str, + clarify_image_uri: str, + endpoint_name: str, + model_bucket_name: str, + model_bias_checkstep_output_prefix: str, + model_bias_checkstep_analysis_config_prefix: Optional[str], + model_bias_output_prefix: str, + ground_truth_prefix: str, + kms_key_id: str, + model_monitor_role_arn: str, + security_group_id: str, + subnet_ids: List[str], + instance_count: int, + instance_type: str, + instance_volume_size_in_gb: int, + max_runtime_in_seconds: int, + features_attribute: Optional[str], + inference_attribute: Optional[str], + probability_attribute: Optional[str], + probability_threshold_attribute: Optional[int], + schedule_expression: str, + **kwargs: Any, + ) -> None: + super().__init__(scope, construct_id, **kwargs) + + # To match the defaults in SageMaker. + if model_bias_checkstep_analysis_config_prefix is None: + model_bias_checkstep_analysis_config_prefix = model_bias_checkstep_output_prefix + + model_bias_job_definition = sagemaker.CfnModelBiasJobDefinition( + self, + "ModelBiasJobDefinition", + job_resources=sagemaker.CfnModelBiasJobDefinition.MonitoringResourcesProperty( + cluster_config=sagemaker.CfnModelBiasJobDefinition.ClusterConfigProperty( + instance_count=instance_count, + instance_type=instance_type, + volume_size_in_gb=instance_volume_size_in_gb, + volume_kms_key_id=kms_key_id, + ) + ), + model_bias_app_specification=sagemaker.CfnModelBiasJobDefinition.ModelBiasAppSpecificationProperty( + config_uri=f"s3://{model_bucket_name}/{model_bias_checkstep_analysis_config_prefix}/analysis_config.json", + image_uri=clarify_image_uri, + ), + model_bias_job_input=sagemaker.CfnModelBiasJobDefinition.ModelBiasJobInputProperty( + ground_truth_s3_input=sagemaker.CfnModelBiasJobDefinition.MonitoringGroundTruthS3InputProperty( + s3_uri=f"s3://{model_bucket_name}/{ground_truth_prefix}" + ), + endpoint_input=sagemaker.CfnModelBiasJobDefinition.EndpointInputProperty( + endpoint_name=endpoint_name, + local_path="/opt/ml/processing/input/model_bias_input", + features_attribute=features_attribute, + inference_attribute=inference_attribute, + probability_attribute=probability_attribute, + probability_threshold_attribute=probability_threshold_attribute, + ), + ), + model_bias_job_output_config=sagemaker.CfnModelBiasJobDefinition.MonitoringOutputConfigProperty( + monitoring_outputs=[ + sagemaker.CfnModelBiasJobDefinition.MonitoringOutputProperty( + s3_output=sagemaker.CfnModelBiasJobDefinition.S3OutputProperty( + local_path="/opt/ml/processing/output/model_bias_output", + s3_uri=f"s3://{model_bucket_name}/{model_bias_output_prefix}", + s3_upload_mode="EndOfJob", + ) + ) + ], + kms_key_id=kms_key_id, + ), + job_definition_name=f"{endpoint_name}-model-bias-def", + role_arn=model_monitor_role_arn, + model_bias_baseline_config=sagemaker.CfnModelBiasJobDefinition.ModelBiasBaselineConfigProperty( + constraints_resource=sagemaker.CfnModelBiasJobDefinition.ConstraintsResourceProperty( + s3_uri=f"s3://{model_bucket_name}/{model_bias_checkstep_output_prefix}/analysis.json" + ) + ), + stopping_condition=sagemaker.CfnModelBiasJobDefinition.StoppingConditionProperty( + max_runtime_in_seconds=max_runtime_in_seconds + ), + network_config=sagemaker.CfnModelBiasJobDefinition.NetworkConfigProperty( + enable_inter_container_traffic_encryption=False, + enable_network_isolation=False, + vpc_config=sagemaker.CfnModelBiasJobDefinition.VpcConfigProperty( + security_group_ids=[security_group_id], subnets=subnet_ids + ), + ), + ) + + model_bias_monitor_schedule = sagemaker.CfnMonitoringSchedule( + self, + "ModelBiasMonitoringSchedule", + monitoring_schedule_config=sagemaker.CfnMonitoringSchedule.MonitoringScheduleConfigProperty( + monitoring_job_definition_name=model_bias_job_definition.job_definition_name, + monitoring_type="ModelBias", + schedule_config=sagemaker.CfnMonitoringSchedule.ScheduleConfigProperty( + schedule_expression=schedule_expression, + ), + ), + monitoring_schedule_name=f"{endpoint_name}-model-bias", + ) + model_bias_monitor_schedule.add_depends_on(model_bias_job_definition) diff --git a/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/settings.py b/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/settings.py index 87a15d8a..74c7731e 100644 --- a/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/settings.py +++ b/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/settings.py @@ -40,6 +40,7 @@ class ModuleSettings(CdkBaseSettings): ground_truth_prefix: str = Field(default="") enable_data_quality_monitor: bool = Field(default=True) enable_model_quality_monitor: bool = Field(default=True) + enable_model_bias_monitor: bool = Field(default=True) # Data quality monitoring options. data_quality_checkstep_output_prefix: str = Field(default="") @@ -49,6 +50,7 @@ class ModuleSettings(CdkBaseSettings): data_quality_instance_volume_size_in_gb: int = Field(default=20, ge=1) data_quality_max_runtime_in_seconds: int = Field(default=3600, ge=1) data_quality_schedule_expression: str = Field(default="cron(0 * ? * * *)") + # Model quality monitoring options. model_quality_checkstep_output_prefix: str = Field(default="") model_quality_output_prefix: str = Field(default="") @@ -62,6 +64,20 @@ class ModuleSettings(CdkBaseSettings): model_quality_probability_threshold_attribute: Optional[int] = Field(default=None) model_quality_schedule_expression: str = Field(default="cron(0 * ? * * *)") + # Model bias monitoring options. + model_bias_checkstep_output_prefix: str = Field(default="") + model_bias_checkstep_analysis_config_prefix: Optional[str] = Field(default=None) + model_bias_output_prefix: str = Field(default="") + model_bias_instance_count: int = Field(default=1, ge=1) + model_bias_instance_type: str = Field(default="ml.m5.large") + model_bias_instance_volume_size_in_gb: int = Field(default=20, ge=1) + model_bias_max_runtime_in_seconds: int = Field(default=1800, ge=1) + model_bias_features_attribute: Optional[str] = Field(default=None) + model_bias_inference_attribute: Optional[str] = Field(default=None) + model_bias_probability_attribute: Optional[str] = Field(default=None) + model_bias_probability_threshold_attribute: Optional[int] = Field(default=None) + model_bias_schedule_expression: str = Field(default="cron(0 * ? * * *)") + tags: Optional[Dict[str, str]] = Field(default=None) diff --git a/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/stack.py b/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/stack.py index 47c3f48e..46291cfe 100644 --- a/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/stack.py +++ b/modules/sagemaker/sagemaker-model-monitoring/sagemaker_model_monitoring/stack.py @@ -8,6 +8,7 @@ from sagemaker import image_uris from sagemaker_model_monitoring.data_quality_construct import DataQualityConstruct +from sagemaker_model_monitoring.model_bias_construct import ModelBiasConstruct from sagemaker_model_monitoring.model_quality_construct import ModelQualityConstruct @@ -17,7 +18,7 @@ class SageMakerModelMonitoringStack(Stack): This stack is deployed to all the deployment environments of the project. - It optionally creates the data quality monitor and model quality monitor + It optionally creates the data quality, model quality, and model bias monitor job constructs and the associated IAM roles and policies. """ @@ -36,6 +37,7 @@ def __init__( ground_truth_prefix: str, enable_data_quality_monitor: bool, enable_model_quality_monitor: bool, + enable_model_bias_monitor: bool, # Data quality monitoring options. data_quality_checkstep_output_prefix: str, data_quality_output_prefix: str, @@ -56,6 +58,19 @@ def __init__( model_quality_probability_attribute: Optional[str], model_quality_probability_threshold_attribute: Optional[int], model_quality_schedule_expression: str, + # Model bias monitoring options. + model_bias_checkstep_output_prefix: str, + model_bias_checkstep_analysis_config_prefix: Optional[str], + model_bias_output_prefix: str, + model_bias_instance_count: int, + model_bias_instance_type: str, + model_bias_instance_volume_size_in_gb: int, + model_bias_max_runtime_in_seconds: int, + model_bias_features_attribute: Optional[str], + model_bias_inference_attribute: Optional[str], + model_bias_probability_attribute: Optional[str], + model_bias_probability_threshold_attribute: Optional[int], + model_bias_schedule_expression: str, **kwargs: Any, ) -> None: super().__init__(scope, id, **kwargs) @@ -70,12 +85,14 @@ def __init__( sagemaker.CfnModel.ContainerDefinitionProperty(model_package_name=model_package_arn) # Error if no monitoring is enabled. - if not enable_data_quality_monitor and not enable_model_quality_monitor: + if not any((enable_data_quality_monitor, enable_model_quality_monitor, enable_model_bias_monitor)): raise ValueError( - "At least one of enable_data_quality_monitor and enable_model_quality_monitor must be True" + "At least one of enable_data_quality_monitor, enable_model_quality_monitor, or " + "enable_model_bias_monitor must be True" ) monitor_image_uri = image_uris.retrieve(framework="model-monitor", region=self.region) + clarify_image_uri = image_uris.retrieve(framework="clarify", region=self.region) model_monitor_policy = iam.ManagedPolicy( self, @@ -190,3 +207,29 @@ def __init__( probability_threshold_attribute=model_quality_probability_threshold_attribute, schedule_expression=model_quality_schedule_expression, ) + + if enable_model_bias_monitor: + ModelBiasConstruct( + self, + "Model Bias Construct", + clarify_image_uri=clarify_image_uri, + endpoint_name=endpoint_name, + model_bucket_name=model_bucket_name, + model_bias_checkstep_output_prefix=model_bias_checkstep_output_prefix, + model_bias_checkstep_analysis_config_prefix=model_bias_checkstep_analysis_config_prefix, + model_bias_output_prefix=model_bias_output_prefix, + ground_truth_prefix=ground_truth_prefix, + kms_key_id=kms_key_id, + model_monitor_role_arn=model_monitor_role.role_arn, + security_group_id=security_group_id, + subnet_ids=subnet_ids, + instance_count=model_bias_instance_count, + instance_type=model_bias_instance_type, + instance_volume_size_in_gb=model_bias_instance_volume_size_in_gb, + max_runtime_in_seconds=model_bias_max_runtime_in_seconds, + features_attribute=model_bias_features_attribute, + inference_attribute=model_bias_inference_attribute, + probability_attribute=model_bias_probability_attribute, + probability_threshold_attribute=model_bias_probability_threshold_attribute, + schedule_expression=model_bias_schedule_expression, + ) diff --git a/modules/sagemaker/sagemaker-model-monitoring/tests/test_app.py b/modules/sagemaker/sagemaker-model-monitoring/tests/test_app.py index d47eac9d..8a64b532 100644 --- a/modules/sagemaker/sagemaker-model-monitoring/tests/test_app.py +++ b/modules/sagemaker/sagemaker-model-monitoring/tests/test_app.py @@ -35,8 +35,7 @@ def test_app(stack_defaults): def test_all_disabled(stack_defaults): os.environ["SEEDFARMER_PARAMETER_ENABLE_DATA_QUALITY_MONITOR"] = "false" os.environ["SEEDFARMER_PARAMETER_ENABLE_MODEL_QUALITY_MONITOR"] = "false" + os.environ["SEEDFARMER_PARAMETER_ENABLE_MODEL_BIAS_MONITOR"] = "false" - with pytest.raises( - Exception, match="At least one of enable_data_quality_monitor and enable_model_quality_monitor must be True" - ): + with pytest.raises(Exception, match="At least one of enable_data_quality_monitor, .+ must be True"): import app # noqa: F401 diff --git a/modules/sagemaker/sagemaker-model-monitoring/tests/test_stack.py b/modules/sagemaker/sagemaker-model-monitoring/tests/test_stack.py index 7bac7afa..9ca3cd61 100644 --- a/modules/sagemaker/sagemaker-model-monitoring/tests/test_stack.py +++ b/modules/sagemaker/sagemaker-model-monitoring/tests/test_stack.py @@ -18,7 +18,9 @@ def stack_defaults() -> None: def stack_model_package_input( - enable_data_quality_monitor: bool = False, enable_model_quality_monitor: bool = False + enable_data_quality_monitor: bool = False, + enable_model_quality_monitor: bool = False, + enable_model_bias_monitor: bool = False, ) -> cdk.Stack: from sagemaker_model_monitoring import settings, stack @@ -48,6 +50,7 @@ def stack_model_package_input( kms_key_id=kms_key_id, enable_data_quality_monitor=enable_data_quality_monitor, enable_model_quality_monitor=enable_model_quality_monitor, + enable_model_bias_monitor=enable_model_bias_monitor, ) return stack.SageMakerModelMonitoringStack( @@ -73,8 +76,16 @@ def test_synthesize_stack_model_quality(stack_defaults: None) -> None: template.resource_count_is("AWS::SageMaker::ModelQualityJobDefinition", 1) +def test_synthesize_stack_model_bias(stack_defaults: None) -> None: + stack = stack_model_package_input(enable_model_bias_monitor=True) + template = Template.from_stack(stack) + template.resource_count_is("AWS::SageMaker::ModelBiasJobDefinition", 1) + + def test_no_cdk_nag_errors(stack_defaults: None) -> None: - stack = stack_model_package_input(enable_data_quality_monitor=True, enable_model_quality_monitor=True) + stack = stack_model_package_input( + enable_data_quality_monitor=True, enable_model_quality_monitor=True, enable_model_bias_monitor=True + ) cdk.Aspects.of(stack).add(cdk_nag.AwsSolutionsChecks()) nag_errors = Annotations.from_stack(stack).find_error(