Skip to content

Commit

Permalink
fix: CDK nags for sagemaker custom kernel module (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonLuttenberger authored Apr 11, 2024
1 parent 8e8baa8 commit 74ab7b7
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 80 deletions.
15 changes: 6 additions & 9 deletions modules/sagemaker/sagemaker-custom-kernel/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os

import aws_cdk as cdk
import cdk_nag
from aws_cdk import CfnOutput

from stack import CustomKernelStack
Expand All @@ -25,19 +26,13 @@ def _param(name: str) -> str:
return f"SEEDFARMER_PARAMETER_{name}"


sagemaker_image_name = os.getenv(
_param("SAGEMAKER_IMAGE_NAME"), DEFAULT_SAGEMAKER_IMAGE_NAME
)
sagemaker_image_name = os.getenv(_param("SAGEMAKER_IMAGE_NAME"), DEFAULT_SAGEMAKER_IMAGE_NAME)
ecr_repo_name = os.getenv(_param("ECR_REPO_NAME")) # type: ignore
app_image_config_name = os.getenv(
_param("APP_IMAGE_CONFIG_NAME"), DEFAULT_APP_IMAGE_CONFIG_NAME
)
app_image_config_name = os.getenv(_param("APP_IMAGE_CONFIG_NAME"), DEFAULT_APP_IMAGE_CONFIG_NAME)
custom_kernel_name = os.getenv(_param("CUSTOM_KERNEL_NAME"), DEFAULT_CUSTOM_KERNEL_NAME)
kernel_user_uid = os.getenv(_param("KERNEL_USER_UID"), DEFAULT_USER_UID)
kernel_user_gid = os.getenv(_param("KERNEL_USER_GID"), DEFAULT_USER_GID)
mount_path = os.getenv(
_param("KERNEL_USER_HOME_MOUNT_PATH"), DEFAULT_KERNEL_USER_HOME_MOUNT_PATH
)
mount_path = os.getenv(_param("KERNEL_USER_HOME_MOUNT_PATH"), DEFAULT_KERNEL_USER_HOME_MOUNT_PATH)
sm_studio_domain_id = os.getenv(_param("STUDIO_DOMAIN_ID"))
sm_studio_domain_name = os.getenv(_param("STUDIO_DOMAIN_NAME"))

Expand Down Expand Up @@ -80,4 +75,6 @@ def _param(name: str) -> str:
),
)

cdk.Aspects.of(app).add(cdk_nag.AwsSolutionsChecks(log_ignores=True))

app.synth()
3 changes: 0 additions & 3 deletions modules/sagemaker/sagemaker-custom-kernel/coverage.ini

This file was deleted.

65 changes: 36 additions & 29 deletions modules/sagemaker/sagemaker-custom-kernel/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,35 +1,42 @@
[tool.black]
[tool.ruff]
exclude = [
".eggs",
".git",
".hg",
".mypy_cache",
".ruff_cache",
".tox",
".venv",
"_build",
"buck-out",
"build",
"dist",
"codeseeder",
]
line-length = 120
target-version = ["py36", "py37", "py38", "py39"]
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| \.env
| _build
| buck-out
| build
| dist
| codeseeder.out
)/
'''
target-version = "py38"

[tool.ruff.lint]
select = ["F", "I", "E", "W"]
fixable = ["ALL"]

[tool.mypy]
python_version = "3.8"
strict = true
ignore_missing_imports = true
disallow_untyped_decorators = false
exclude = "codeseeder.out/|example/|tests/|scripts/"
warn_unused_ignores = false

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
py_version = 38
skip_gitignore = false

[tool.pytest.ini_options]
addopts = "-v --cov=. --cov-report term --cov-config=coverage.ini --cov-fail-under=80"
addopts = "-v --cov=. --cov-report term"
pythonpath = [
"."
]
]

[tool.coverage.run]
omit = ["tests/*"]

[tool.coverage.report]
fail_under = 80
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

# Module Parameters
image_name = os.getenv("SEEDFARMER_PARAMETER_SAGEMAKER_IMAGE_NAME", "echo-kernel")
app_image_config_name = os.getenv(
"SEEDFARMER_PARAMETER_APP_IMAGE_CONFIG_NAME", "echo-kernel"
)
app_image_config_name = os.getenv("SEEDFARMER_PARAMETER_APP_IMAGE_CONFIG_NAME", "echo-kernel")
sm_studio_domain_id = os.environ.get("SEEDFARMER_PARAMETER_STUDIO_DOMAIN_ID")
sm_studio_domain_name = os.environ.get("SEEDFARMER_PARAMETER_STUDIO_DOMAIN_NAME")

Expand Down Expand Up @@ -47,13 +45,9 @@ def update_domain():
merged_distinct_custom_images = list(
dict((v["AppImageConfigName"], v) for v in existing_custom_images).values(),
)
default_user_settings["KernelGatewayAppSettings"]["CustomImages"] = (
merged_distinct_custom_images
)
default_user_settings["KernelGatewayAppSettings"]["CustomImages"] = merged_distinct_custom_images

print(
f"Updating Sagemaker Studio Domain - {sm_studio_domain_name} ({sm_studio_domain_id})"
)
print(f"Updating Sagemaker Studio Domain - {sm_studio_domain_name} ({sm_studio_domain_id})")
print(default_user_settings)
sm_client.update_domain(
DomainId=sm_studio_domain_id,
Expand Down
9 changes: 0 additions & 9 deletions modules/sagemaker/sagemaker-custom-kernel/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,3 @@ exclude =
codeseeder.out,
bundle
tests

[mypy]
python_version = 3.7
strict = True
ignore_missing_imports = True
allow_untyped_decorators = True
exclude =
codeseeder.out/|example/|tests/|scripts/
warn_unused_ignores = False
29 changes: 11 additions & 18 deletions modules/sagemaker/sagemaker-custom-kernel/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import os
from typing import Any

import cdk_nag
from aws_cdk import Aspects, Stack, Tags
from aws_cdk import Stack, Tags
from aws_cdk import aws_ecr as ecr
from aws_cdk import aws_iam as iam
from aws_cdk import aws_sagemaker as sagemaker
Expand Down Expand Up @@ -35,9 +34,7 @@ def __init__(
Tags.of(self).add(key="Deployment", value=app_prefix[:64])

# ECR Image deployment
repo = ecr.Repository.from_repository_name(
self, id=f"{app_prefix}-ecr-repo", repository_name=ecr_repo_name
)
repo = ecr.Repository.from_repository_name(self, id=f"{app_prefix}-ecr-repo", repository_name=ecr_repo_name)

local_image = DockerImageAsset(
self,
Expand All @@ -63,9 +60,7 @@ def __init__(
role_name=f"{app_prefix}-image-role",
assumed_by=iam.ServicePrincipal("sagemaker.amazonaws.com"),
managed_policies=[
iam.ManagedPolicy.from_aws_managed_policy_name(
"AmazonSageMakerFullAccess"
),
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSageMakerFullAccess"),
],
)

Expand Down Expand Up @@ -107,22 +102,20 @@ def __init__(
)
app_image_config.node.add_dependency(image_version)

Aspects.of(self).add(cdk_nag.AwsSolutionsChecks())
NagSuppressions.add_stack_suppressions(
self,
apply_to_nested_stacks=True,
suppressions=[
NagPackSuppression(
**{
"id": "AwsSolutions-IAM4",
"reason": "Image Role needs Sagemaker Full Access",
}
id="AwsSolutions-IAM4",
reason="Image Role needs Sagemaker Full Access",
applies_to=[
"Policy::arn:<AWS::Partition>:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole",
"Policy::arn:<AWS::Partition>:iam::aws:policy/AmazonSageMakerFullAccess",
],
),
NagPackSuppression(
**{
"id": "AwsSolutions-IAM5",
"reason": "ECR Deployment Service Role needs Full Access",
}
id="AwsSolutions-IAM5",
reason="ECR Deployment Service Role needs Full Access",
),
],
)
20 changes: 17 additions & 3 deletions modules/sagemaker/sagemaker-custom-kernel/tests/test_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import sys

import aws_cdk as cdk
import cdk_nag
import pytest
from aws_cdk.assertions import Template
from aws_cdk.assertions import Annotations, Match, Template


@pytest.fixture(scope="function")
Expand All @@ -20,7 +21,8 @@ def stack_defaults():
del sys.modules["stack"]


def test_synthesize_stack(stack_defaults):
@pytest.fixture(scope="function")
def stack(stack_defaults) -> cdk.Stack:
import stack

app = cdk.App()
Expand All @@ -36,7 +38,7 @@ def test_synthesize_stack(stack_defaults):
kernel_user_gid = 100
mount_path = "/root"

stack = stack.CustomKernelStack(
return stack.CustomKernelStack(
app,
app_prefix,
env=cdk.Environment(
Expand All @@ -53,9 +55,21 @@ def test_synthesize_stack(stack_defaults):
mount_path=mount_path,
)


def test_synthesize_stack(stack: cdk.Stack) -> None:
template = Template.from_stack(stack)

template.resource_count_is("AWS::IAM::Role", 2)
template.resource_count_is("AWS::SageMaker::Image", 1)
template.resource_count_is("AWS::SageMaker::ImageVersion", 1)
template.resource_count_is("AWS::SageMaker::AppImageConfig", 1)


def test_no_cdk_nag_errors(stack: cdk.Stack) -> None:
cdk.Aspects.of(stack).add(cdk_nag.AwsSolutionsChecks())

nag_errors = Annotations.from_stack(stack).find_error(
"*",
Match.string_like_regexp(r"AwsSolutions-.*"),
)
assert not nag_errors, f"Found {len(nag_errors)} CDK nag errors"

0 comments on commit 74ab7b7

Please sign in to comment.