diff --git a/dbt_pumpkin/plan.py b/dbt_pumpkin/plan.py index d68348d..c9380fb 100644 --- a/dbt_pumpkin/plan.py +++ b/dbt_pumpkin/plan.py @@ -3,6 +3,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING from dbt_pumpkin.data import ResourceType @@ -208,6 +209,11 @@ def execute(self, files: dict[Path, dict]): yaml_columns.extend(reordered_columns) +class ExecutionMode(Enum): + RUN = "run" + DRY_RUN = "dry_run" + + class Plan: def __init__(self, actions: list[Action]): self.actions = actions @@ -215,7 +221,7 @@ def __init__(self, actions: list[Action]): def _affected_files(self) -> set[Path]: return {f for a in self.actions for f in a.affected_files()} - def execute(self, storage: Storage): + def execute(self, storage: Storage, mode: ExecutionMode): if not self.actions: logger.info("Nothing to do") return @@ -229,8 +235,9 @@ def execute(self, storage: Storage): logger.info("Action %s: %s", index + 1, action.describe()) action.execute(files) - logger.info("Persisting changes to files: %s", len(affected_files)) - storage.save_yaml(files) + if mode == ExecutionMode.RUN: + logger.info("Persisting changes to files: %s", len(affected_files)) + storage.save_yaml(files) def describe(self) -> str: return "\n".join(a.describe() for a in self.actions) diff --git a/dbt_pumpkin/pumpkin.py b/dbt_pumpkin/pumpkin.py index a430825..c0af70c 100644 --- a/dbt_pumpkin/pumpkin.py +++ b/dbt_pumpkin/pumpkin.py @@ -1,8 +1,10 @@ import logging +from typing import Callable from dbt_pumpkin.loader import ResourceLoader from dbt_pumpkin.params import ProjectParams, ResourceParams -from dbt_pumpkin.planner import BootstrapPlanner, RelocationPlanner, SynchronizationPlanner +from dbt_pumpkin.plan import ExecutionMode +from dbt_pumpkin.planner import ActionPlanner, BootstrapPlanner, RelocationPlanner, SynchronizationPlanner from dbt_pumpkin.storage import DiskStorage logger = logging.getLogger(__name__) @@ -13,33 +15,37 @@ def __init__(self, project_params: ProjectParams, resource_params: ResourceParam self.project_params = project_params self.resource_params = resource_params - def bootstrap(self, *, dry_run: bool): + def _execute(self, create_planner: Callable[[ResourceLoader], ActionPlanner], *, dry_run: bool): loader = ResourceLoader(self.project_params, self.resource_params) - resources = loader.select_resources() - planner = BootstrapPlanner(resources) + logger.debug("Creating action planner") + planner = create_planner(loader) plan = planner.plan() - storage = DiskStorage(loader.locate_project_dir(), loader.detect_yaml_format(), read_only=dry_run) - plan.execute(storage) + storage = DiskStorage(loader.locate_project_dir(), loader.detect_yaml_format()) + mode = ExecutionMode.DRY_RUN if dry_run else ExecutionMode.RUN - def relocate(self, *, dry_run: bool): - loader = ResourceLoader(self.project_params, self.resource_params) - resources = loader.select_resources() + logger.info("Plan execution mode: %s", mode) + plan.execute(storage, mode) - planner = RelocationPlanner(resources) - plan = planner.plan() + def bootstrap(self, *, dry_run: bool): + def create_planner(loader: ResourceLoader) -> ActionPlanner: + resources = loader.select_resources() + return BootstrapPlanner(resources) - storage = DiskStorage(loader.locate_project_dir(), loader.detect_yaml_format(), read_only=dry_run) - plan.execute(storage) + self._execute(create_planner, dry_run=dry_run) - def synchronize(self, *, dry_run: bool): - loader = ResourceLoader(self.project_params, self.resource_params) - resources = loader.select_resources() - tables = loader.lookup_tables() + def relocate(self, *, dry_run: bool): + def create_planner(loader: ResourceLoader) -> ActionPlanner: + resources = loader.select_resources() + return RelocationPlanner(resources) - planner = SynchronizationPlanner(resources, tables) - plan = planner.plan() + self._execute(create_planner, dry_run=dry_run) + + def synchronize(self, *, dry_run: bool): + def create_planner(loader: ResourceLoader) -> ActionPlanner: + resources = loader.select_resources() + tables = loader.lookup_tables() + return SynchronizationPlanner(resources, tables) - storage = DiskStorage(loader.locate_project_dir(), loader.detect_yaml_format(), read_only=dry_run) - plan.execute(storage) + self._execute(create_planner, dry_run=dry_run) diff --git a/dbt_pumpkin/storage.py b/dbt_pumpkin/storage.py index 0553b05..df80940 100644 --- a/dbt_pumpkin/storage.py +++ b/dbt_pumpkin/storage.py @@ -25,9 +25,8 @@ def save_yaml(self, files: dict[Path, any]): class DiskStorage(Storage): - def __init__(self, root_dir: Path, yaml_format: YamlFormat | None, *, read_only: bool): + def __init__(self, root_dir: Path, yaml_format: YamlFormat | None): self._root_dir = root_dir - self._read_only = read_only self._yaml = YAML(typ="rt") self._yaml.preserve_quotes = True @@ -51,12 +50,9 @@ def load_yaml(self, files: set[Path]) -> dict[Path, any]: return result def save_yaml(self, files: dict[Path, any]): - if self._read_only: - return - for file, content in files.items(): - resolved_file = self._root_dir / file - logger.debug("Saving file: %s", resolved_file) resolved_file = self._root_dir / file resolved_file.parent.mkdir(exist_ok=True) + + logger.debug("Saving file: %s", resolved_file) self._yaml.dump(content, resolved_file) diff --git a/tests/test_storage.py b/tests/test_storage.py index 11886e0..3eef2b4 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -17,13 +17,13 @@ def test_load_yaml(tmp_path: Path): """) ) - storage = DiskStorage(tmp_path, yaml_format=None, read_only=False) + storage = DiskStorage(tmp_path, yaml_format=None) files = storage.load_yaml({Path("schema.yml"), Path("absent.yml")}) assert files == {Path("schema.yml"): {"version": 2, "models": [{"name": "my_model"}]}} def test_save_yaml(tmp_path: Path): - storage = DiskStorage(tmp_path, yaml_format=None, read_only=False) + storage = DiskStorage(tmp_path, yaml_format=None) storage.save_yaml({Path("schema.yml"): {"version": 2, "models": [{"name": "my_other_model"}]}}) @@ -32,7 +32,7 @@ def test_save_yaml(tmp_path: Path): def test_save_yaml_default_format(tmp_path: Path): - storage = DiskStorage(tmp_path, yaml_format=None, read_only=False) + storage = DiskStorage(tmp_path, yaml_format=None) storage.save_yaml( { @@ -62,7 +62,7 @@ def test_save_yaml_default_format(tmp_path: Path): def test_save_yaml_format(tmp_path: Path): yaml_format = YamlFormat(indent=2, offset=2) - storage = DiskStorage(tmp_path, yaml_format, read_only=False) + storage = DiskStorage(tmp_path, yaml_format) storage.save_yaml( { @@ -90,13 +90,6 @@ def test_save_yaml_format(tmp_path: Path): assert actual == expected -def test_save_yaml_read_only(tmp_path: Path): - storage = DiskStorage(tmp_path, yaml_format=None, read_only=True) - storage.save_yaml({Path("schema.yml"): {"version": 2, "models": [{"name": "my_other_model"}]}}) - - assert not (tmp_path / "schema.yml").exists() - - def test_roundtrip_preserve_comments(tmp_path: Path): content = textwrap.dedent("""\ version: 2 @@ -112,7 +105,7 @@ def test_roundtrip_preserve_comments(tmp_path: Path): (tmp_path / "my_model.yml").write_text(content) - storage = DiskStorage(tmp_path, yaml_format=None, read_only=False) + storage = DiskStorage(tmp_path, yaml_format=None) files = storage.load_yaml({Path("my_model.yml")}) storage.save_yaml(files) @@ -153,7 +146,7 @@ def test_roundtrip_preserve_quotes(tmp_path: Path): (tmp_path / "my_model.yml").write_text(content) - storage = DiskStorage(tmp_path, yaml_format=None, read_only=False) + storage = DiskStorage(tmp_path, yaml_format=None) files = storage.load_yaml({Path("my_model.yml")}) yaml = files[Path("my_model.yml")]