From 067271255e2e8d99e75f7e8483611059c8522196 Mon Sep 17 00:00:00 2001 From: Georges Lorre Date: Mon, 4 Dec 2023 11:57:30 +0100 Subject: [PATCH 1/5] Add functionality for pullthrough cache rule creation and URI patching --- src/fondant/pipeline/compiler.py | 46 +++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/fondant/pipeline/compiler.py b/src/fondant/pipeline/compiler.py index 4e94fc18c..4f05bdd09 100644 --- a/src/fondant/pipeline/compiler.py +++ b/src/fondant/pipeline/compiler.py @@ -521,15 +521,20 @@ def _set_configuration(self, task, fondant_component_operation): class SagemakerCompiler(Compiler): def __init__(self): + self.ecr_namespace = "fndnt-mirror" self._resolve_imports() + self.ecr_client = self.boto3.client("ecr") + self._check_ecr_pull_through_rule(namespace=self.ecr_namespace) def _resolve_imports(self): try: + import boto3 import sagemaker import sagemaker.processing import sagemaker.workflow.pipeline import sagemaker.workflow.steps + self.boto3 = boto3 self.sagemaker = sagemaker except ImportError: @@ -578,6 +583,43 @@ def _get_build_command( return command + def _check_ecr_pull_through_rule(self, namespace: str) -> None: + logging.info("Checking existing pull through cache rules for 'fndnt-mirror'") + + try: + self.ecr_client.describe_pull_through_cache_rules( + ecrRepositoryPrefixes=[namespace], + ) + except self.ecr_client.exceptions._code_to_exception[ + "PullThroughCacheRuleNotFoundException" + ]: + logging.info( + """Pull through cache rule for 'fndnt-mirror' not found.. + creating pull through cache rule for 'fndnt-mirror'""", + ) + + self.ecr_client.create_pull_through_cache_rule( + ecrRepositoryPrefix=namespace, + upstreamRegistryUrl="public.ecr.aws", + ) + + logging.info( + "Pull through cache rule for 'fndnt-mirror' created successfully", + ) + + def _patch_uri(self, og_uri: str) -> str: + uri, tag = og_uri.split(":") + + # force pullthrough cache to be used + _ = self.ecr_client.batch_get_image( + repositoryName=f"{self.ecr_namespace}/{uri}", + imageIds=[{"imageTag": tag}], + ) + repo_response = self.ecr_client.describe_repositories( + repositoryNames=[f"{self.ecr_namespace}/{uri}"], + ) + return repo_response["repositories"][0]["repositoryUri"] + ":" + tag + def compile( self, pipeline: Pipeline, @@ -641,19 +683,21 @@ def compile( role_arn = self.sagemaker.get_execution_role() processor = self.sagemaker.processing.ScriptProcessor( - image_uri=component_op.component_spec.image, + image_uri=self._patch_uri(component_op.component_spec.image), command=["bash"], instance_type=instance_type, instance_count=1, base_job_name=component_name, role=role_arn, ) + step = self.sagemaker.workflow.steps.ProcessingStep( name=component_name, processor=processor, depends_on=depends_on, code=script_path, ) + steps.append(step) sagemaker_pipeline = self.sagemaker.workflow.pipeline.Pipeline( From eb9497fca40d4b2a4fa13c0ee5c9e951fa621fdf Mon Sep 17 00:00:00 2001 From: Georges Lorre Date: Mon, 4 Dec 2023 12:01:14 +0100 Subject: [PATCH 2/5] Update log messages --- src/fondant/pipeline/compiler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fondant/pipeline/compiler.py b/src/fondant/pipeline/compiler.py index 4f05bdd09..5487091cc 100644 --- a/src/fondant/pipeline/compiler.py +++ b/src/fondant/pipeline/compiler.py @@ -584,7 +584,7 @@ def _get_build_command( return command def _check_ecr_pull_through_rule(self, namespace: str) -> None: - logging.info("Checking existing pull through cache rules for 'fndnt-mirror'") + logging.info(f"Checking existing pull through cache rules for '{namespace}'") try: self.ecr_client.describe_pull_through_cache_rules( @@ -594,8 +594,8 @@ def _check_ecr_pull_through_rule(self, namespace: str) -> None: "PullThroughCacheRuleNotFoundException" ]: logging.info( - """Pull through cache rule for 'fndnt-mirror' not found.. - creating pull through cache rule for 'fndnt-mirror'""", + f"""Pull through cache rule for '{namespace}' not found.. + creating pull through cache rule for '{namespace}'""", ) self.ecr_client.create_pull_through_cache_rule( @@ -604,7 +604,7 @@ def _check_ecr_pull_through_rule(self, namespace: str) -> None: ) logging.info( - "Pull through cache rule for 'fndnt-mirror' created successfully", + f"Pull through cache rule for '{namespace}' created successfully", ) def _patch_uri(self, og_uri: str) -> str: From 22791e49c12b7982d1c0ff02cd027916f631fd7b Mon Sep 17 00:00:00 2001 From: Georges Lorre Date: Mon, 4 Dec 2023 13:40:16 +0100 Subject: [PATCH 3/5] Add base_path validation --- src/fondant/pipeline/compiler.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/fondant/pipeline/compiler.py b/src/fondant/pipeline/compiler.py index 5487091cc..b3dcfe5fb 100644 --- a/src/fondant/pipeline/compiler.py +++ b/src/fondant/pipeline/compiler.py @@ -524,7 +524,6 @@ def __init__(self): self.ecr_namespace = "fndnt-mirror" self._resolve_imports() self.ecr_client = self.boto3.client("ecr") - self._check_ecr_pull_through_rule(namespace=self.ecr_namespace) def _resolve_imports(self): try: @@ -583,28 +582,30 @@ def _get_build_command( return command - def _check_ecr_pull_through_rule(self, namespace: str) -> None: - logging.info(f"Checking existing pull through cache rules for '{namespace}'") + def _check_ecr_pull_through_rule(self) -> None: + logging.info( + f"Checking existing pull through cache rules for '{self.ecr_namespace}'", + ) try: self.ecr_client.describe_pull_through_cache_rules( - ecrRepositoryPrefixes=[namespace], + ecrRepositoryPrefixes=[self.ecr_namespace], ) except self.ecr_client.exceptions._code_to_exception[ "PullThroughCacheRuleNotFoundException" ]: logging.info( - f"""Pull through cache rule for '{namespace}' not found.. - creating pull through cache rule for '{namespace}'""", + f"""Pull through cache rule for '{self.ecr_namespace}' not found.. + creating pull through cache rule for '{self.ecr_namespace}'""", ) self.ecr_client.create_pull_through_cache_rule( - ecrRepositoryPrefix=namespace, + ecrRepositoryPrefix=self.ecr_namespace, upstreamRegistryUrl="public.ecr.aws", ) logging.info( - f"Pull through cache rule for '{namespace}' created successfully", + f"Pull through cache rule for '{self.ecr_namespace}' created successfully", ) def _patch_uri(self, og_uri: str) -> str: @@ -620,6 +621,17 @@ def _patch_uri(self, og_uri: str) -> str: ) return repo_response["repositories"][0]["repositoryUri"] + ":" + tag + def validate_base_path(self, base_path: str) -> None: + file_prefix, storage_path = base_path.split("://") + + if file_prefix != "s3": + msg = "base_path must be a valid s3 path, starting with s3://" + raise ValueError(msg) + + if storage_path.endswith("/"): + msg = "base_path must not end with a '/'" + raise ValueError(msg) + def compile( self, pipeline: Pipeline, @@ -639,6 +651,9 @@ def compile( role_arn: the Amazon Resource Name role to use for the processing steps, if none provided the `sagemaker.get_execution_role()` role will be used. """ + self.validate_base_path(pipeline.base_path) + self._check_ecr_pull_through_rule() + run_id = pipeline.get_run_id() path = pipeline.base_path pipeline.validate(run_id=run_id) From 785c3af9893a52cf23df62b2114f371f4381e0a2 Mon Sep 17 00:00:00 2001 From: Georges Lorre Date: Mon, 4 Dec 2023 13:54:08 +0100 Subject: [PATCH 4/5] Fix tests --- src/fondant/pipeline/compiler.py | 2 +- tests/pipeline/test_compiler.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/fondant/pipeline/compiler.py b/src/fondant/pipeline/compiler.py index b3dcfe5fb..299250dc0 100644 --- a/src/fondant/pipeline/compiler.py +++ b/src/fondant/pipeline/compiler.py @@ -523,7 +523,6 @@ class SagemakerCompiler(Compiler): def __init__(self): self.ecr_namespace = "fndnt-mirror" self._resolve_imports() - self.ecr_client = self.boto3.client("ecr") def _resolve_imports(self): try: @@ -651,6 +650,7 @@ def compile( role_arn: the Amazon Resource Name role to use for the processing steps, if none provided the `sagemaker.get_execution_role()` role will be used. """ + self.ecr_client = self.boto3.client("ecr") self.validate_base_path(pipeline.base_path) self._check_ecr_pull_through_rule() diff --git a/tests/pipeline/test_compiler.py b/tests/pipeline/test_compiler.py index f087059fe..539155ff8 100644 --- a/tests/pipeline/test_compiler.py +++ b/tests/pipeline/test_compiler.py @@ -656,3 +656,20 @@ def test_sagemaker_generate_script(tmp_path_factory): with open(script_path) as f: assert f.read() == "fondant execute main echo hello world" + + +def test_sagemaker_base_path_validator(): + compiler = SagemakerCompiler() + + # no lowercase 's3' + with pytest.raises( + ValueError, + match="base_path must be a valid s3 path, starting with s3://", + ): + compiler.validate_base_path("S3://foo/bar") + # ends with '/' + with pytest.raises(ValueError, match="base_path must not end with a '/'"): + compiler.validate_base_path("s3://foo/bar/") + + # valid + compiler.validate_base_path("s3://foo/bar") From cebb3d95480d4457fde5f9fb0f06b10fb6ed57c5 Mon Sep 17 00:00:00 2001 From: Georges Lorre Date: Mon, 4 Dec 2023 15:42:27 +0100 Subject: [PATCH 5/5] Fix sagemaker cli runner --- src/fondant/cli.py | 1 + src/fondant/pipeline/compiler.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fondant/cli.py b/src/fondant/cli.py index 7b412857a..140015d12 100644 --- a/src/fondant/cli.py +++ b/src/fondant/cli.py @@ -640,6 +640,7 @@ def register_run(parent_parser): local_parser.set_defaults(func=run_local) kubeflow_parser.set_defaults(func=run_kfp) vertex_parser.set_defaults(func=run_vertex) + sagemaker_parser.set_defaults(func=run_sagemaker) def run_local(args): diff --git a/src/fondant/pipeline/compiler.py b/src/fondant/pipeline/compiler.py index 299250dc0..80d09b2e2 100644 --- a/src/fondant/pipeline/compiler.py +++ b/src/fondant/pipeline/compiler.py @@ -519,7 +519,7 @@ def _set_configuration(self, task, fondant_component_operation): return task -class SagemakerCompiler(Compiler): +class SagemakerCompiler(Compiler): # pragma: no cover def __init__(self): self.ecr_namespace = "fndnt-mirror" self._resolve_imports()