Skip to content

Commit

Permalink
fix: More style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasfarias committed Feb 19, 2023
1 parent 0e754ef commit 19976f9
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 29 deletions.
14 changes: 12 additions & 2 deletions airflow_dbt_python/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Iterable, Iterator, NamedTuple, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
NamedTuple,
Optional,
Tuple,
Union,
)
from urllib.parse import urlparse

from airflow.exceptions import AirflowException
Expand All @@ -21,7 +31,7 @@
from airflow_dbt_python.utils.configs import BaseConfig
from airflow_dbt_python.utils.url import URLLike

DbtRemoteHooksDict = dict[tuple[str, Optional[str]], DbtRemoteHook]
DbtRemoteHooksDict = Dict[Tuple[str, Optional[str]], DbtRemoteHook]


class DbtTaskResult(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions airflow_dbt_python/hooks/git.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A concrete DbtRemoteHook for git repositories with dulwich."""
import datetime as dt
from typing import Callable, Optional, Union
from typing import Callable, Optional, Tuple, Union

from airflow.providers.ssh.hooks.ssh import SSHHook
from dulwich.client import HttpGitClient, SSHGitClient, TCPGitClient
Expand Down Expand Up @@ -155,7 +155,7 @@ def download(

client.clone(path, str(destination), mkdir=not destination.exists())

def get_git_client_path(self, url: URL) -> tuple[GitClients, str]:
def get_git_client_path(self, url: URL) -> Tuple[GitClients, str]:
"""Initialize a dulwich git client according to given URL's scheme."""
if url.scheme == "git":
client: GitClients = TCPGitClient(url.hostname, url.port)
Expand Down
16 changes: 5 additions & 11 deletions airflow_dbt_python/operators/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
import datetime as dt
import os
from dataclasses import asdict, is_dataclass
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom import XCOM_RETURN_KEY

from airflow_dbt_python.utils.enums import LogFormat, Output

if TYPE_CHECKING:
from dbt.contracts.results import RunResult

from airflow_dbt_python.hooks.dbt import DbtHook, DbtTaskResult
from airflow_dbt_python.utils.enums import LogFormat, Output

base_template_fields = [
"project_dir",
Expand Down Expand Up @@ -71,7 +69,7 @@ def __init__(
warn_error: Optional[bool] = None,
# Logging
debug: Optional[bool] = None,
log_format: Optional[str] = None,
log_format: Optional[LogFormat] = None,
log_cache_events: Optional[bool] = False,
quiet: Optional[bool] = None,
no_print: Optional[bool] = None,
Expand Down Expand Up @@ -119,9 +117,7 @@ def __init__(
self.log_cache_events = log_cache_events
self.quiet = quiet
self.no_print = no_print
self.log_format = (
LogFormat.from_str(log_format) if log_format is not None else None
)
self.log_format = log_format
self.record_timing_info = record_timing_info

self.dbt_defer = defer
Expand Down Expand Up @@ -508,7 +504,7 @@ def __init__(
select: Optional[list[str]] = None,
exclude: Optional[list[str]] = None,
selector_name: Optional[str] = None,
dbt_output: Optional[str] = None,
dbt_output: Optional[Output] = None,
output_keys: Optional[list[str]] = None,
indirect_selection: Optional[str] = None,
**kwargs,
Expand All @@ -518,9 +514,7 @@ def __init__(
self.select = select
self.exclude = exclude
self.selector_name = selector_name
self.dbt_output = (
Output.from_str(dbt_output) if dbt_output is not None else None
)
self.dbt_output = dbt_output
self.output_keys = output_keys
self.indirect_selection = indirect_selection

Expand Down
4 changes: 2 additions & 2 deletions airflow_dbt_python/utils/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import pickle
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Optional, Type, Union

import dbt.flags as flags
import yaml
Expand Down Expand Up @@ -492,7 +492,7 @@ class ListTaskConfig(SelectionConfig):

cls: Type[BaseTask] = dataclasses.field(default=ListTask, init=False)
indirect_selection: Optional[str] = None
output: Output = Output.SELECTOR
output: Output = Output["selector"]
output_keys: Optional[list[str]] = None
resource_types: Optional[list[str]] = None
which: str = dataclasses.field(default="list", init=False)
Expand Down
4 changes: 2 additions & 2 deletions airflow_dbt_python/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def __eq__(self, other) -> bool:
return Enum.__eq__(self, other)


class LogFormat(FromStrEnum):
class LogFormat(str, Enum):
"""Allowed dbt log formats."""

DEFAULT = "default"
JSON = "json"
TEXT = "text"


class Output(FromStrEnum):
class Output(str, Enum):
"""Allowed output arguments."""

JSON = "json"
Expand Down
24 changes: 19 additions & 5 deletions airflow_dbt_python/utils/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,18 @@ def is_relative_to(self, base: Union[str, "URL"]) -> bool:
>>> URL("s3://test-s3-bucket/project/data/").is_relative_to("different")
False
"""
if isinstance(base, URL):
return self.path.is_relative_to(base.path)
else:
return self.path.is_relative_to(base)
is_relative = True
try:
# is_relative_to was added in Python 3.9 and we have to support 3.7 and 3.8.
if isinstance(base, URL):
self.path.relative_to(base.path)
else:
self.path.relative_to(base)

except ValueError:
is_relative = False

return is_relative

def join(self, relative: str) -> "URL":
"""Return a new URL by joining this with relative.
Expand Down Expand Up @@ -279,7 +287,13 @@ def unlink(self, missing_ok: bool = False) -> None:
if self.is_local() is False:
raise ValueError("Cannot unlink remote file.")

self.path.unlink(missing_ok)
try:
self.path.unlink()
except FileNotFoundError:
# In python 3.8, the missing_ok parameter was added to ignore these
# exceptions. Once we drop Python 3.7 support, we can remove this block.
if missing_ok:
raise

def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None:
"""Call this URL's underlying Path's mkdir."""
Expand Down
5 changes: 4 additions & 1 deletion tests/hooks/dbt/test_dbt_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def test_dbt_build_task(
assert len(result.run_results) == 12

for run_result in result.run_results:
assert run_result.status == RunStatus.Success or run_result.status == TestStatus.Pass
assert (
run_result.status == RunStatus.Success
or run_result.status == TestStatus.Pass
)


def test_dbt_build_task_non_existent_model(
Expand Down
6 changes: 3 additions & 3 deletions tests/utils/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_base_config_create_dbt_profile_with_extra_target(
project_dir=dbt_project_file.parent,
profiles_dir=profiles_file.parent,
)
extra_target = hook.get_target_from_connection(conn_id)
extra_target = hook.get_dbt_target_from_connection(conn_id)

profile = config.create_dbt_profile(extra_target)
assert profile.profile_name == "default"
Expand All @@ -281,7 +281,7 @@ def test_base_config_create_dbt_profile_with_extra_target_no_profile(
config = BaseConfig(
target=conn_id, project_dir=dbt_project_file.parent, profiles_dir=None
)
extra_target = hook.get_target_from_connection(conn_id)
extra_target = hook.get_dbt_target_from_connection(conn_id)

profile = config.create_dbt_profile(extra_target)
assert profile.profile_name == "default"
Expand Down Expand Up @@ -356,7 +356,7 @@ def test_base_config_create_dbt_project_and_profile_with_no_profile(
for conn_id in airflow_conns:
config.target = conn_id

extra_target = hook.get_target_from_connection(conn_id)
extra_target = hook.get_dbt_target_from_connection(conn_id)
project, profile = config.create_dbt_project_and_profile(extra_target)

assert project.model_paths == ["models"]
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit test module for URL utility."""
from tarfile import TarFile
from typing import Dict
from zipfile import ZipFile

import pytest
Expand Down Expand Up @@ -63,7 +64,7 @@ def test_url_initialize(urllike: URLLike, expected: bool):
),
),
)
def test_url_initialize_from_parts(parts: dict[str, str], expected: URL):
def test_url_initialize_from_parts(parts: Dict[str, str], expected: URL):
"""Test parsing of URLs during initialization."""
result = URL.from_parts(**parts)
assert result == expected
Expand Down

0 comments on commit 19976f9

Please sign in to comment.