diff --git a/.changes/unreleased/Fixes-20240715-205355.yaml b/.changes/unreleased/Fixes-20240715-205355.yaml new file mode 100644 index 00000000..780a6ad9 --- /dev/null +++ b/.changes/unreleased/Fixes-20240715-205355.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix case-insensitive env vars for Windows +time: 2024-07-15T20:53:55.946355+01:00 +custom: + Author: peterallenwebb aranke + Issue: "166" diff --git a/dbt_common/context.py b/dbt_common/context.py index d1775c55..947d409a 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -1,15 +1,46 @@ +import os from contextvars import ContextVar, copy_context -from typing import List, Mapping, Optional +from typing import List, Mapping, Optional, Iterator from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX from dbt_common.record import Recorder +class CaseInsensitiveMapping(Mapping): + def __init__(self, env: Mapping[str, str]): + self._env = {k.casefold(): (k, v) for k, v in env.items()} + + def __getitem__(self, key: str) -> str: + return self._env[key.casefold()][1] + + def __len__(self) -> int: + return len(self._env) + + def __iter__(self) -> Iterator[str]: + for item in self._env.items(): + yield item[0] + + class InvocationContext: def __init__(self, env: Mapping[str, str]): - self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} + self._env: Mapping[str, str] + + env_public = {} + env_private = {} + + for k, v in env.items(): + if k.startswith(PRIVATE_ENV_PREFIX): + env_private[k] = v + else: + env_public[k] = v + + if os.name == "nt": + self._env = CaseInsensitiveMapping(env_public) + else: + self._env = env_public + self._env_secrets: Optional[List[str]] = None - self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} + self._env_private = env_private self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. diff --git a/pyproject.toml b/pyproject.toml index 64fc04fc..ba306437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ ignore = ["E203", "E501", "E741", "W503", "W504"] exclude = [ "dbt_common/events/types_pb2.py", "venv", + ".venv", "env*" ] per-file-ignores = ["*/__init__.py: F401"] diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index 3dc832d3..fbf060ba 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -1,5 +1,9 @@ +import os + +import pytest + from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX -from dbt_common.context import InvocationContext +from dbt_common.context import InvocationContext, CaseInsensitiveMapping def test_invocation_context_env() -> None: @@ -8,6 +12,17 @@ def test_invocation_context_env() -> None: assert ic.env == test_env +@pytest.mark.skipif( + os.name != "nt", reason="Test for case-insensitive env vars, only run on Windows" +) +def test_invocation_context_windows() -> None: + test_env = {"var_1": "lowercase", "vAr_2": "mixedcase", "VAR_3": "uppercase"} + ic = InvocationContext(env=test_env) + assert ic.env == CaseInsensitiveMapping( + {"var_1": "lowercase", "var_2": "mixedcase", "var_3": "uppercase"} + ) + + def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1",