Skip to content

Commit

Permalink
Add vertex runner (#429)
Browse files Browse the repository at this point in the history
adresses: #417
  • Loading branch information
GeorgesLorre committed Oct 2, 2023
1 parent a39cdbc commit b317152
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ classifiers = [
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
]

[tool.poetry.dependencies]
python = ">= 3.8 < 3.12"
python = ">= 3.8 < 3.10"
dask = {extras = ["dataframe", "distributed", "diagnostics"], version = ">= 2023.4.1"}
importlib-resources = { version = ">= 1.3", python = "<3.9" }
jsonschema = ">= 4.18"
Expand All @@ -53,13 +52,14 @@ s3fs = { version = ">= 2023.4.0", optional = true }
adlfs = { version = ">= 2023.4.0", optional = true }
kfp = { version = "2.0.1", optional = true }
pandas = { version = ">= 1.3.5", optional = true }
google-cloud-aiplatform = { version = "1.32.0", optional = true}

[tool.poetry.extras]
aws = ["fsspec", "s3fs"]
azure = ["fsspec", "adlfs"]
gcp = ["fsspec", "gcsfs"]
kfp = ["kfp"]
vertex = ["kfp"]
vertex = ["kfp", "google-cloud-aiplatform"]

[tool.poetry.group.test.dependencies]
pre-commit = "^3.1.1"
Expand Down
36 changes: 36 additions & 0 deletions src/fondant/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import subprocess # nosec
import typing as t
from abc import ABC, abstractmethod

import yaml
Expand Down Expand Up @@ -85,3 +86,38 @@ def get_name_from_spec(self, input_spec: str):
with open(input_spec) as f:
spec = yaml.safe_load(f)
return spec["pipelineInfo"]["name"]


class VertexRunner(Runner):
def __resolve_imports(self):
import google.cloud.aiplatform as aip

self.aip = aip

def __init__(
self,
project_id: str,
project_region: str,
service_account: t.Optional[str] = None,
):
self.__resolve_imports()

self.aip.init(
project=project_id,
location=project_region,
)
self.service_account = service_account

def run(self, input_spec: str, *args, **kwargs):
job = self.aip.PipelineJob(
display_name=self.get_name_from_spec(input_spec),
template_path=input_spec,
enable_caching=False,
)
job.submit(service_account=self.service_account)

def get_name_from_spec(self, input_spec: str):
"""Get the name of the pipeline from the spec."""
with open(input_spec) as f:
spec = yaml.safe_load(f)
return spec["pipelineInfo"]["name"]
19 changes: 18 additions & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import mock

import pytest
from fondant.runner import DockerRunner, KubeflowRunner
from fondant.runner import DockerRunner, KubeflowRunner, VertexRunner

VALID_PIPELINE = Path("./tests/example_pipelines/compiled_pipeline/")

Expand Down Expand Up @@ -79,3 +79,20 @@ def test_kfp_import():
sys.modules["kfp"] = None
with pytest.raises(ImportError):
_ = KubeflowRunner(host="some_host")


def test_vertex_runner():
input_spec_path = str(VALID_PIPELINE / "kubeflow_pipeline.yml")
with mock.patch("google.cloud.aiplatform.init", return_value=None), mock.patch(
"google.cloud.aiplatform.PipelineJob",
):
runner = VertexRunner(project_id="some_project", project_region="some_region")
runner.run(input_spec=input_spec_path)

# test with service account
runner2 = VertexRunner(
project_id="some_project",
project_region="some_region",
service_account="some_account",
)
runner2.run(input_spec=input_spec_path)

0 comments on commit b317152

Please sign in to comment.