diff --git a/pyproject.toml b/pyproject.toml
index e735d3c52..7ac3f3cd9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,7 @@ docker = {version = ">= 6.1.3", optional = true }
 kfp = { version = "2.3.0", optional = true, extras =["kubernetes"] }
 google-cloud-aiplatform = { version = "1.34.0", optional = true}
 sagemaker = {version = ">= 2.197.0", optional = true}
+boto3 = {version = "1.28.64", optional = true}
 
 [tool.poetry.extras]
 component = ["dask"]
@@ -69,7 +70,7 @@ gcp = ["gcsfs"]
 kfp = ["docker", "kfp"]
 vertex = ["docker", "kfp", "google-cloud-aiplatform"]
 
-sagemaker = ["sagemaker"]
+sagemaker = ["sagemaker", "boto3"]
 docker = ["docker"]
 
 [tool.poetry.group.test.dependencies]
diff --git a/src/fondant/pipeline/runner.py b/src/fondant/pipeline/runner.py
index c79dad0ef..108305c22 100644
--- a/src/fondant/pipeline/runner.py
+++ b/src/fondant/pipeline/runner.py
@@ -65,7 +65,7 @@ def run(
             experiment = self.client.get_experiment(experiment_name=experiment_name)
         except ValueError:
             logger.info(
-                f"Defined experiment '{experiment_name}' not found. Creating new experiment"
+                f"defined experiment '{experiment_name}' not found. creating new experiment"
                 f" under this name",
             )
             experiment = self.client.create_experiment(experiment_name)
@@ -126,3 +126,49 @@ def get_name_from_spec(self, input_spec: str):
         with open(input_spec) as f:
             spec = yaml.safe_load(f)
         return spec["pipelineInfo"]["name"]
+
+
+class SagemakerRunner(Runner):
+    def __init__(self):
+        self.__resolve_imports()
+        self.client = self.boto3.client("sagemaker")
+
+    def __resolve_imports(self):
+        import boto3
+
+        self.boto3 = boto3
+
+    def run(self, input_spec: str, pipeline_name: str, role_arn: str, *args, **kwargs):
+        """Creates/updates a sagemaker pipeline and execute it."""
+        with open(input_spec) as f:
+            pipeline = f.read()
+        pipelines = self.client.list_pipelines(
+            PipelineNamePrefix=pipeline_name,
+        )
+        if pipelines["PipelineSummaries"]:
+            logging.info(
+                f"Pipeline with name {pipeline_name} already exists, updating it",
+            )
+            _ = self.client.update_pipeline(
+                PipelineName=pipeline_name,
+                PipelineDefinition=pipeline,
+                RoleArn=role_arn,
+            )
+        else:
+            logging.info(
+                f"Pipeline with name {pipeline_name} does not exist, creating it",
+            )
+            _ = self.client.create_pipeline(
+                PipelineName=pipeline_name,
+                PipelineDefinition=pipeline,
+                RoleArn=role_arn,
+            )
+
+        logging.info(f"Starting pipeline execution for pipeline {pipeline_name}")
+        _ = self.client.start_pipeline_execution(
+            PipelineName=pipeline_name,
+            ParallelismConfiguration={"MaxParallelExecutionSteps": 1},
+        )
+        logging.info(
+            "Pipeline execution started for pipeline, visit Sagemaker studio to follow up",
+        )
diff --git a/tests/test_runner.py b/tests/test_runner.py
index 15bc4e89a..84ad63304 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -4,7 +4,12 @@ from unittest import mock
 
 import pytest
 
-from fondant.pipeline.runner import DockerRunner, KubeflowRunner, VertexRunner
+from fondant.pipeline.runner import (
+    DockerRunner,
+    KubeflowRunner,
+    SagemakerRunner,
+    VertexRunner,
+)
 
 VALID_PIPELINE = Path("./tests/example_pipelines/compiled_pipeline/")
 
@@ -96,3 +101,58 @@ def test_vertex_runner():
         service_account="some_account",
     )
     runner2.run(input_spec=input_spec_path)
+
+
+def test_sagemaker_runner(tmp_path_factory):
+    with mock.patch("boto3.client", spec=True), tmp_path_factory.mktemp(
+        "temp",
+    ) as tmpdir:
+        # create a small temporary spec file
+        with open(tmpdir / "spec.json", "w") as f:
+            f.write('{"pipelineInfo": {"name": "pipeline_1"}}')
+        runner = SagemakerRunner()
+
+        runner.run(
+            input_spec=tmpdir / "spec.json",
+            pipeline_name="pipeline_1",
+            role_arn="arn:something",
+        )
+
+        # check which methods were called on the client
+        assert runner.client.method_calls == [
+            mock.call.list_pipelines(PipelineNamePrefix="pipeline_1"),
+            mock.call.update_pipeline(
+                PipelineName="pipeline_1",
+                PipelineDefinition='{"pipelineInfo": {"name": "pipeline_1"}}',
+                RoleArn="arn:something",
+            ),
+            mock.call.start_pipeline_execution(
+                PipelineName="pipeline_1",
+                ParallelismConfiguration={"MaxParallelExecutionSteps": 1},
+            ),
+        ]
+
+        # reset the mock and test the creation of a new pipeline
+        runner.client.reset_mock()
+        runner.client.configure_mock(
+            **{"list_pipelines.return_value": {"PipelineSummaries": []}},
+        )
+
+        runner.run(
+            input_spec=tmpdir / "spec.json",
+            pipeline_name="pipeline_1",
+            role_arn="arn:something",
+        )
+        # here we expect the create_pipeline method to be called
+        assert runner.client.method_calls == [
+            mock.call.list_pipelines(PipelineNamePrefix="pipeline_1"),
+            mock.call.create_pipeline(
+                PipelineName="pipeline_1",
+                PipelineDefinition='{"pipelineInfo": {"name": "pipeline_1"}}',
+                RoleArn="arn:something",
+            ),
+            mock.call.start_pipeline_execution(
+                PipelineName="pipeline_1",
+                ParallelismConfiguration={"MaxParallelExecutionSteps": 1},
+            ),
+        ]