Feature/migrate to pydantic v2 (#176)
* Assign default values to fields, update grouping config cls field validator.

* Skip validating invalid grouping config in the test.

* Fix order of fields in grouping config.

* Fix how the test config is accessed.

* Cleanup old values from test config.

* Specify that pydantic version should be smaller than 3.

* Update changelog.
sfczekalski authored Sep 3, 2024
1 parent 1d3645c commit 59075c0
Showing 11 changed files with 181 additions and 98 deletions.
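The migration steps listed in the commit message above come down to a handful of pydantic v2 API renames that recur through the diffs below. A minimal before/after sketch — the model and field names are illustrative, not taken from this repository:

```python
from typing import Optional

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    image: str
    # pydantic v2 no longer treats Optional[...] as implying a default,
    # so optional fields need an explicit `= None` (or another default)
    root: Optional[str] = None


raw = {"image": "gcr.io/some-project/some-image"}

cfg = ExampleConfig.model_validate(raw)  # v1: ExampleConfig.parse_obj(raw)
data = cfg.model_dump()                  # v1: cfg.dict()
text = cfg.model_dump_json()             # v1: cfg.json()
```

Field validators move from `@validator` to `@field_validator`; a separate sketch after the `GroupingConfig` hunk below shows that change.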
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and it's only available in `kfp-kubernetes` extension package
didn't alter the remote pipeline execution, and only escaped the local Python process. The timeout functionality will be added later on,
with the proper remote pipeline execution handling, and possibly per-task timeout enabled by [the new kfp feature](https://github.com/kubeflow/pipelines/pull/10481).
- Assign pipelines to Vertex AI experiments
- Migrated `pydantic` library to v2

## [0.11.1] - 2024-07-01

26 changes: 13 additions & 13 deletions kedro_vertexai/config.py
@@ -4,7 +4,7 @@
from inspect import signature
from typing import Dict, List, Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from pydantic.networks import IPvAnyAddress

DEFAULT_CONFIG_TEMPLATE = """
@@ -166,13 +166,13 @@ class GroupingConfig(BaseModel):
cls: str = "kedro_vertexai.grouping.IdentityNodeGrouper"
params: Optional[dict] = {}

@validator("cls")
@field_validator("cls")
def class_valid(cls, v, values, **kwargs):
try:
grouper_class = dynamic_load_class(v)
class_sig = signature(grouper_class)
if "params" in values:
class_sig.bind(None, **values["params"])
if "params" in values.data:
class_sig.bind(None, **values.data["params"])
else:
class_sig.bind(None)
except: # noqa: E722
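For context, the hunk above keeps the old validator body and only switches `values` to `values.data`; the documented v2 shape is to accept a `ValidationInfo` argument and read previously validated fields from its `.data` mapping. A standalone sketch of that pattern — class, field, and module names here are illustrative, and `importlib` stands in for the plugin's `dynamic_load_class` helper:

```python
import importlib
from inspect import signature
from typing import Optional

from pydantic import BaseModel, ValidationInfo, field_validator


class GrouperConfigSketch(BaseModel):
    # params is declared before cls_name, so it has already been validated
    # and is reachable through info.data by the time cls_name is checked
    params: Optional[dict] = {}
    cls_name: str = "mypackage.grouping.IdentityNodeGrouper"

    @field_validator("cls_name")
    @classmethod
    def cls_name_valid(cls, v: str, info: ValidationInfo) -> str:
        # v1 handed previously validated fields to the validator as a plain
        # `values` dict; v2 exposes them as info.data
        params = info.data.get("params") or {}
        module_name, _, class_name = v.rpartition(".")
        try:
            grouper_class = getattr(importlib.import_module(module_name), class_name)
            # check that the grouper accepts the configured params, roughly
            # what the plugin's class_valid does via signature(...).bind(...)
            signature(grouper_class).bind(None, **params)
        except Exception as exc:  # import failure or mismatched params
            raise ValueError(f"Grouping class {v!r} is invalid: {exc}") from exc
        return v
```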
@@ -196,13 +196,13 @@ class HostAliasConfig(BaseModel):


class ResourcesConfig(BaseModel):
cpu: Optional[str]
gpu: Optional[str]
memory: Optional[str]
cpu: Optional[str] = None
gpu: Optional[str] = None
memory: Optional[str] = None


class NetworkConfig(BaseModel):
vpc: Optional[str]
vpc: Optional[str] = None
host_aliases: Optional[List[HostAliasConfig]] = []


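The `= None` additions in this hunk are behavioural, not cosmetic: in pydantic v2 an `Optional[...]` annotation no longer implies a default, so a field annotated `cpu: Optional[str]` becomes required (it may be set to None, but it must be set). A small sketch with an illustrative model:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class ResourcesSketch(BaseModel):
    cpu: Optional[str]            # v2: required, though None is an accepted value
    memory: Optional[str] = None  # v2: truly optional, defaults to None


try:
    ResourcesSketch()
except ValidationError as exc:
    # the missing-field error is reported for `cpu`, not `memory`
    print(exc.errors()[0]["loc"])  # ('cpu',)

print(ResourcesSketch(cpu=None).memory)  # None
```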
@@ -212,7 +212,7 @@ class DynamicConfigProviderConfig(BaseModel):


class MLFlowVertexAIConfig(BaseModel):
request_header_provider_params: Optional[Dict[str, str]]
request_header_provider_params: Optional[Dict[str, str]] = None


class ScheduleConfig(BaseModel):
@@ -227,13 +227,13 @@ class ScheduleConfig(BaseModel):

class RunConfig(BaseModel):
image: str
root: Optional[str]
description: Optional[str]
root: Optional[str] = None
description: Optional[str] = None
experiment_name: str
experiment_description: Optional[str] = None
scheduled_run_name: Optional[str]
scheduled_run_name: Optional[str] = None
grouping: Optional[GroupingConfig] = GroupingConfig()
service_account: Optional[str]
service_account: Optional[str] = None
network: Optional[NetworkConfig] = NetworkConfig()
ttl: int = 3600 * 24 * 7
resources: Optional[Dict[str, ResourcesConfig]] = dict(
2 changes: 1 addition & 1 deletion kedro_vertexai/context_helper.py
@@ -82,7 +82,7 @@ def config(self) -> PluginConfig:
"Missing vertexai.yml files in configuration. "
"Make sure that you configure your project first"
)
return PluginConfig.parse_obj(vertex_conf)
return PluginConfig.model_validate(vertex_conf)

@cached_property
def vertexai_client(self) -> VertexAIPipelinesClient:
168 changes: 125 additions & 43 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ tabulate = ">=0.8.7"
semver = ">=2.10,<4.0.0"
toposort = ">1.0,<2.0"
pyarrow = ">=14.0.1" # Stating explicitly for sub-dependency due to critical vulnerability
pydantic = ">=1.9.0,<2.0.0" # so far blocked by kedro-mlflow at 0.11.10 & kfp <2.0
pydantic = ">=2,<3"
google-auth = "<3"
google-cloud-scheduler = ">=2.3.2"
google-cloud-iam = "<3"
21 changes: 11 additions & 10 deletions tests/test_cli.py
@@ -2,6 +2,7 @@
import os
import unittest
from collections import namedtuple
from copy import deepcopy
from itertools import product
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -41,7 +42,7 @@ def test_list_pipelines(self):

def test_run_once(self):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -67,7 +68,7 @@ def test_run_once(self):

def test_run_once_with_wait(self):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -109,7 +110,7 @@ def test_docker_push(self):

def test_run_once_auto_build(self):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -146,7 +147,7 @@ def test_run_once_auto_build(self):
@patch("webbrowser.open_new_tab")
def test_ui(self, open_new_tab):
context_helper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -160,7 +161,7 @@ def test_ui(self, open_new_tab):

def test_compile(self):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -177,7 +178,7 @@ def test_compile(self):

def test_store_params_empty(self):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -215,7 +216,7 @@ def test_store_params_exiting_config_yaml(self):
Covers the case when there is an existing config.yaml in the pwd
"""
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
config = dict(context_helper=context_helper)
runner = CliRunner()

@@ -260,7 +261,7 @@ def test_store_params_exiting_config_yaml(self):

def test_schedule(self):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)

mock_schedule = MagicMock()
context_helper.config.run_config.schedules = {
@@ -312,7 +313,7 @@ def test_schedule(self):
@patch.object(Path, "cwd")
def test_init(self, cwd):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
context_helper.context.project_name = "Test Project"
context_helper.context.project_path.name = "test_project_path"
config = dict(context_helper=context_helper)
@@ -333,7 +334,7 @@ def test_init(self, cwd):
@patch.object(Path, "cwd")
def test_init_with_github_actions(self, cwd):
context_helper: ContextHelper = MagicMock(ContextHelper)
context_helper.config = test_config
context_helper.config = deepcopy(test_config)
context_helper.context.project_name = "Test Project"
context_helper.context.project_path.name = "test_project_path"
config = dict(context_helper=context_helper)
43 changes: 23 additions & 20 deletions tests/test_config.py
@@ -18,7 +18,9 @@
description: "My awesome pipeline"
service_account: [email protected]
grouping:
cls: kedro_vertexai.grouping.IdentityNodeGrouper
cls: "kedro_vertexai.grouping.IdentityNodeGrouper"
params:
tag_prefix: "group."
ttl: 300
network:
vpc: my-vpc
@@ -49,26 +51,25 @@

class TestPluginConfig(unittest.TestCase):
def test_grouping_config(self):
cfg = PluginConfig.parse_obj(yaml.safe_load(CONFIG_MINIMAL))
cfg = PluginConfig.model_validate(yaml.safe_load(CONFIG_MINIMAL))
assert cfg.run_config.grouping is not None
assert (
cfg.run_config.grouping.cls == "kedro_vertexai.grouping.IdentityNodeGrouper"
)
c_obj = dynamic_init_class(cfg.run_config.grouping.cls, None)
assert isinstance(c_obj, IdentityNodeGrouper)

cfg_tag_group = """
project_id: some-project
region: some-region
run_config:
image: test
experiment_name: test
grouping:
cls: "kedro_vertexai.grouping.TagNodeGrouper"
cls: kedro_vertexai.grouping.TagNodeGrouper
params:
tag_prefix: "group."
"""
cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group))
cfg = PluginConfig.model_validate(yaml.safe_load(cfg_tag_group))
assert cfg.run_config.grouping is not None
c_obj = dynamic_init_class(
cfg.run_config.grouping.cls, None, **cfg.run_config.grouping.params
@@ -89,16 +90,18 @@ def test_grouping_config_error(self, log_error):
params:
foo: "bar:"
"""
cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group))
cfg = yaml.safe_load(cfg_tag_group)
c = dynamic_init_class(
cfg.run_config.grouping.cls, None, **cfg.run_config.grouping.params
cfg["run_config"]["grouping"]["cls"],
None,
**cfg["run_config"]["grouping"]["params"]
)
assert c is None
log_error.assert_called_once()

def test_plugin_config(self):
obj = yaml.safe_load(CONFIG_FULL)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert cfg.run_config.image == "gcr.io/project-image/test"
assert cfg.run_config.experiment_name == "Test Experiment"
assert cfg.run_config.experiment_description == "Test Experiment Description."
@@ -116,16 +119,16 @@ def test_plugin_config(self):
assert cfg.run_config.ttl == 300

def test_defaults(self):
cfg = PluginConfig.parse_obj(yaml.safe_load(CONFIG_MINIMAL))
cfg = PluginConfig.model_validate(yaml.safe_load(CONFIG_MINIMAL))
assert cfg.run_config.description is None
assert cfg.run_config.ttl == 3600 * 24 * 7

def test_missing_required_config(self):
with self.assertRaises(ValidationError):
PluginConfig.parse_obj({})
PluginConfig.model_validate({})

def test_resources_default_only(self):
cfg = PluginConfig.parse_obj(yaml.safe_load(CONFIG_MINIMAL))
cfg = PluginConfig.model_validate(yaml.safe_load(CONFIG_MINIMAL))
assert cfg.run_config.resources_for("node2") == {
"cpu": "500m",
"gpu": None,
@@ -138,7 +141,7 @@
}

def test_node_selectors_default_only(self):
cfg = PluginConfig.parse_obj(yaml.safe_load(CONFIG_MINIMAL))
cfg = PluginConfig.model_validate(yaml.safe_load(CONFIG_MINIMAL))
assert cfg.run_config.node_selectors_for("node2") == {}
assert cfg.run_config.node_selectors_for("node3") == {}

@@ -147,7 +150,7 @@ def test_resources_no_default(self):
obj["run_config"].update(
{"resources": {"__default__": {"cpu": None, "memory": None}}}
)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert cfg.run_config.resources_for("node2") == {
"cpu": None,
"gpu": None,
@@ -164,7 +167,7 @@ def test_resources_default_and_node_specific(self):
}
}
)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert cfg.run_config.resources_for("node2") == {
"cpu": "100m",
"gpu": None,
@@ -186,7 +189,7 @@ def test_resources_default_and_tag_specific(self):
}
}
)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert cfg.run_config.resources_for("node2", {"tag1"}) == {
"cpu": "100m",
"gpu": "2",
@@ -209,7 +212,7 @@ def test_resources_node_and_tag_specific(self):
}
}
)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert cfg.run_config.resources_for("node2", {"tag1"}) == {
"cpu": "300m",
"gpu": "2",
@@ -226,7 +229,7 @@ def test_node_selectors_node_and_tag_specific(self):
}
}
)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert cfg.run_config.node_selectors_for("node2", {"tag1"}) == {
"cloud.google.com/gke-accelerator": "NVIDIA_TESLA_K80",
}
@@ -246,7 +249,7 @@ def test_parse_network_config(self):
}
}
)
cfg = PluginConfig.parse_obj(obj)
cfg = PluginConfig.model_validate(obj)
assert (
cfg.run_config.network.vpc
== "projects/some-project-id/global/networks/some-vpc-name"
@@ -255,15 +258,15 @@
assert "mlflow.internal" in cfg.run_config.network.host_aliases[0].hostnames

def test_accept_default_vertex_ai_networking_config(self):
cfg = PluginConfig.parse_obj(yaml.safe_load(CONFIG_MINIMAL))
cfg = PluginConfig.model_validate(yaml.safe_load(CONFIG_MINIMAL))
assert cfg.run_config.network.vpc is None
assert cfg.run_config.network.host_aliases == []

@unittest.skip(
"Scheduling feature is temporarily disabled https://github.com/getindata/kedro-vertexai/issues/4"
)
def test_reuse_run_name_for_scheduled_run_name(self):
cfg = PluginConfig.parse_obj(
cfg = PluginConfig.model_validate(
{
"run_config": {
"scheduled_run_name": "some run",
5 changes: 3 additions & 2 deletions tests/test_config_providers.py
@@ -1,5 +1,6 @@
"""Test Dynamic Config Providers"""

import json
import logging
import unittest
from copy import deepcopy
@@ -59,7 +60,7 @@ def _get_test_config_with_dynamic_provider(
self,
class_name=MLFlowGoogleOAuthCredentialsProvider.full_name(),
) -> PluginConfig:
config_raw = deepcopy(test_config.dict())
config_raw = deepcopy(json.loads(test_config.model_dump_json()))
config_raw["run_config"]["dynamic_config_providers"] = [
{
"cls": class_name,
Expand All @@ -69,7 +70,7 @@ def _get_test_config_with_dynamic_provider(
},
}
]
config = PluginConfig.parse_obj(config_raw)
config = PluginConfig.model_validate(config_raw)
return config

def test_initialization_from_config(self):
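One note on the `model_dump_json()` + `json.loads()` round trip above: `model_dump()` (the direct v2 successor of `.dict()`) keeps rich Python values such as `IPvAnyAddress` objects, whereas the JSON round trip flattens everything down to plain primitives — the same shape `yaml.safe_load` would produce. A sketch under that reading, with an illustrative model rather than the plugin's real config:

```python
import json
from ipaddress import IPv4Address
from typing import List

from pydantic import BaseModel
from pydantic.networks import IPvAnyAddress


class HostAliasSketch(BaseModel):
    ip: IPvAnyAddress
    hostnames: List[str]


alias = HostAliasSketch(ip="10.0.0.1", hostnames=["mlflow.internal"])

# python-mode dump keeps the rich ipaddress object ...
assert isinstance(alias.model_dump()["ip"], IPv4Address)

# ... while the JSON round trip leaves only plain strings, lists and dicts
plain = json.loads(alias.model_dump_json())
assert plain == {"ip": "10.0.0.1", "hostnames": ["mlflow.internal"]}

# model_dump(mode="json") is the built-in shortcut for the same result
assert alias.model_dump(mode="json") == plain
```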
2 changes: 1 addition & 1 deletion tests/test_vertex_ai_client.py
@@ -10,7 +10,7 @@

class TestVertexAIClient(unittest.TestCase):
def create_client(self):
config = PluginConfig.parse_obj(
config = PluginConfig.model_validate(
{
"project_id": "PROJECT_ID",
"region": "REGION",