Skip to content

Commit

Permalink
Add functionality for pullthrough cache rule creation and URI patching (
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgesLorre authored Dec 4, 2023
1 parent 2b47ba0 commit dd738f9
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/fondant/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def register_run(parent_parser):
local_parser.set_defaults(func=run_local)
kubeflow_parser.set_defaults(func=run_kfp)
vertex_parser.set_defaults(func=run_vertex)
sagemaker_parser.set_defaults(func=run_sagemaker)


def run_local(args):
Expand Down
63 changes: 61 additions & 2 deletions src/fondant/pipeline/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,17 +519,20 @@ def _set_configuration(self, task, fondant_component_operation):
return task


class SagemakerCompiler(Compiler):
class SagemakerCompiler(Compiler): # pragma: no cover
def __init__(self):
self.ecr_namespace = "fndnt-mirror"
self._resolve_imports()

def _resolve_imports(self):
try:
import boto3
import sagemaker
import sagemaker.processing
import sagemaker.workflow.pipeline
import sagemaker.workflow.steps

self.boto3 = boto3
self.sagemaker = sagemaker

except ImportError:
Expand Down Expand Up @@ -578,6 +581,56 @@ def _get_build_command(

return command

def _check_ecr_pull_through_rule(self) -> None:
logging.info(
f"Checking existing pull through cache rules for '{self.ecr_namespace}'",
)

try:
self.ecr_client.describe_pull_through_cache_rules(
ecrRepositoryPrefixes=[self.ecr_namespace],
)
except self.ecr_client.exceptions._code_to_exception[
"PullThroughCacheRuleNotFoundException"
]:
logging.info(
f"""Pull through cache rule for '{self.ecr_namespace}' not found..
creating pull through cache rule for '{self.ecr_namespace}'""",
)

self.ecr_client.create_pull_through_cache_rule(
ecrRepositoryPrefix=self.ecr_namespace,
upstreamRegistryUrl="public.ecr.aws",
)

logging.info(
f"Pull through cache rule for '{self.ecr_namespace}' created successfully",
)

def _patch_uri(self, og_uri: str) -> str:
uri, tag = og_uri.split(":")

# force pullthrough cache to be used
_ = self.ecr_client.batch_get_image(
repositoryName=f"{self.ecr_namespace}/{uri}",
imageIds=[{"imageTag": tag}],
)
repo_response = self.ecr_client.describe_repositories(
repositoryNames=[f"{self.ecr_namespace}/{uri}"],
)
return repo_response["repositories"][0]["repositoryUri"] + ":" + tag

def validate_base_path(self, base_path: str) -> None:
file_prefix, storage_path = base_path.split("://")

if file_prefix != "s3":
msg = "base_path must be a valid s3 path, starting with s3://"
raise ValueError(msg)

if storage_path.endswith("/"):
msg = "base_path must not end with a '/'"
raise ValueError(msg)

def compile(
self,
pipeline: Pipeline,
Expand All @@ -597,6 +650,10 @@ def compile(
role_arn: the Amazon Resource Name role to use for the processing steps,
if none provided the `sagemaker.get_execution_role()` role will be used.
"""
self.ecr_client = self.boto3.client("ecr")
self.validate_base_path(pipeline.base_path)
self._check_ecr_pull_through_rule()

run_id = pipeline.get_run_id()
path = pipeline.base_path
pipeline.validate(run_id=run_id)
Expand Down Expand Up @@ -641,19 +698,21 @@ def compile(
role_arn = self.sagemaker.get_execution_role()

processor = self.sagemaker.processing.ScriptProcessor(
image_uri=component_op.component_spec.image,
image_uri=self._patch_uri(component_op.component_spec.image),
command=["bash"],
instance_type=instance_type,
instance_count=1,
base_job_name=component_name,
role=role_arn,
)

step = self.sagemaker.workflow.steps.ProcessingStep(
name=component_name,
processor=processor,
depends_on=depends_on,
code=script_path,
)

steps.append(step)

sagemaker_pipeline = self.sagemaker.workflow.pipeline.Pipeline(
Expand Down
17 changes: 17 additions & 0 deletions tests/pipeline/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,20 @@ def test_sagemaker_generate_script(tmp_path_factory):

with open(script_path) as f:
assert f.read() == "fondant execute main echo hello world"


def test_sagemaker_base_path_validator():
compiler = SagemakerCompiler()

# no lowercase 's3'
with pytest.raises(
ValueError,
match="base_path must be a valid s3 path, starting with s3://",
):
compiler.validate_base_path("S3://foo/bar")
# ends with '/'
with pytest.raises(ValueError, match="base_path must not end with a '/'"):
compiler.validate_base_path("s3://foo/bar/")

# valid
compiler.validate_base_path("s3://foo/bar")

0 comments on commit dd738f9

Please sign in to comment.