
Commit

copy test updates from python3.13 branch
sh-rp committed Jan 8, 2025
1 parent 602ddeb commit 5e4f829
Showing 10 changed files with 55 additions and 30 deletions.
2 changes: 1 addition & 1 deletion dlt/common/typing.py
@@ -328,7 +328,7 @@ def is_typeddict(t: Type[Any]) -> bool:

def is_annotated(ann_type: Any) -> bool:
try:
-return issubclass(get_origin(ann_type), Annotated) # type: ignore[arg-type]
+return get_origin(ann_type) is Annotated
except TypeError:
return False

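The new is_annotated() relies on documented typing behavior: get_origin() returns the Annotated special form itself for Annotated[...] hints, so an identity check is enough, and it avoids handing that special form to issubclass(), which is not guaranteed to accept it on newer interpreters (a plausible reading of this change, given the python3.13 branch it was copied from). A minimal sketch of the facts the identity check relies on:

    from typing import Annotated, get_args, get_origin

    hint = Annotated[int, "positive"]
    # documented behavior: the origin of an Annotated hint is Annotated itself
    assert get_origin(hint) is Annotated
    assert get_args(hint) == (int, "positive")

    # non-Annotated hints resolve to a different origin, or to None
    assert get_origin(list[int]) is list
    assert get_origin(int) is None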
7 changes: 3 additions & 4 deletions dlt/destinations/impl/databricks/databricks.py
@@ -28,8 +28,7 @@
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema.typing import TColumnType
from dlt.common.storages import FilesystemConfiguration, fsspec_from_config

-from dlt.destinations.insert_job_client import InsertValuesJobClient
+from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException
from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration
from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient
@@ -198,7 +197,7 @@ def gen_delete_from_sql(
"""


-class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination):
+class DatabricksClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
def __init__(
self,
schema: Schema,
@@ -213,7 +212,7 @@ def __init__(
)
super().__init__(schema, config, sql_client)
self.config: DatabricksClientConfiguration = config
-self.sql_client: DatabricksSqlClient = sql_client
+self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment, unused-ignore]
self.type_mapper = self.capabilities.get_type_mapper()

def create_load_job(
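A likely motivation for the added # type: ignore[assignment, unused-ignore] (an assumption, not something the commit states): the attribute is re-annotated with a narrower client type than the base class declares, which only some mypy and dependency combinations flag, and unused-ignore keeps --warn-unused-ignores quiet in environments where no error fires. A generic sketch of that pattern with made-up class names:

    class BaseSqlClient:
        ...

    class NarrowSqlClient(BaseSqlClient):
        ...

    class BaseJobClient:
        sql_client: BaseSqlClient

    class NarrowJobClient(BaseJobClient):
        def __init__(self, sql_client: NarrowSqlClient) -> None:
            # re-annotating the inherited attribute with a subtype can trip
            # mypy's assignment check on some versions; unused-ignore keeps
            # the comment harmless where the error does not fire
            self.sql_client: NarrowSqlClient = sql_client  # type: ignore[assignment, unused-ignore]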
8 changes: 4 additions & 4 deletions dlt/destinations/impl/databricks/sql_client.py
@@ -16,7 +16,7 @@
)

from databricks.sdk.core import Config, oauth_service_principal
-from databricks import sql as databricks_lib # type: ignore[attr-defined]
+from databricks import sql as databricks_lib
from databricks.sql.client import (
Connection as DatabricksSqlConnection,
Cursor as DatabricksSqlCursor,
@@ -43,7 +43,7 @@
class DatabricksCursorImpl(DBApiCursorImpl):
"""Use native data frame support if available"""

-native_cursor: DatabricksSqlCursor
+native_cursor: DatabricksSqlCursor # type: ignore[assignment, unused-ignore]
vector_size: ClassVar[int] = 2048 # vector size is 2048

def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
@@ -140,7 +140,6 @@ def execute_sql(
@contextmanager
@raise_database_error
def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]:
-curr: DBApiCursor
# TODO: Inline param support will be dropped in future databricks driver, switch to :named paramstyle
# This will drop support for cluster runtime v13.x
# db_args: Optional[Dict[str, Any]]
@@ -159,10 +158,11 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB
# else:
# db_args = kwargs or None

+assert isinstance(query, str)
db_args = args or kwargs or None
with self._conn.cursor() as curr:
curr.execute(query, db_args)
-yield DatabricksCursorImpl(curr) # type: ignore[abstract]
+yield DatabricksCursorImpl(curr) # type: ignore[arg-type, abstract, unused-ignore]

def catalog_name(self, escape: bool = True) -> Optional[str]:
catalog = self.capabilities.casefold_identifier(self.credentials.catalog)
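The rewritten execute_query() keeps the commented-out paramstyle branches for reference and reduces the live logic to db_args = args or kwargs or None, while the new assert isinstance(query, str) narrows the AnyStr parameter before it reaches the cursor. A standalone sketch of that bundling rule (not the dlt implementation):

    from typing import Any, Optional, Union

    def bundle_params(*args: Any, **kwargs: Any) -> Optional[Union[tuple, dict]]:
        # positional parameters win, then keyword parameters, else None so
        # the cursor receives no parameter collection at all
        return args or kwargs or None

    assert bundle_params(1, 2) == (1, 2)
    assert bundle_params(name="dlt") == {"name": "dlt"}
    assert bundle_params() is None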
4 changes: 2 additions & 2 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -536,7 +536,7 @@ def update_schema_in_storage(self) -> None:
self.schema.naming.normalize_identifier(
"engine_version"
): self.schema.ENGINE_VERSION,
self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()),
self.schema.naming.normalize_identifier("inserted_at"): pendulum.now(),
self.schema.naming.normalize_identifier("schema_name"): self.schema.name,
self.schema.naming.normalize_identifier(
"version_hash"
@@ -693,7 +693,7 @@ def complete_load(self, load_id: str) -> None:
self.schema.naming.normalize_identifier("load_id"): load_id,
self.schema.naming.normalize_identifier("schema_name"): self.schema.name,
self.schema.naming.normalize_identifier("status"): 0,
self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()),
self.schema.naming.normalize_identifier("inserted_at"): pendulum.now(),
self.schema.naming.normalize_identifier("schema_version_hash"): None,
}
]
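Both system-table writes now store inserted_at as the pendulum value itself rather than its string form, so downstream handling sees a real timezone-aware timestamp instead of ISO text (how the column is ultimately typed is up to the client; this sketch only shows what changes about the value):

    import datetime

    import pendulum

    now = pendulum.now()
    # pendulum.DateTime subclasses datetime.datetime, so it stays a usable,
    # timezone-aware timestamp; str(now) would have frozen it into ISO text
    assert isinstance(now, datetime.datetime)
    assert now.tzinfo is not None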
9 changes: 7 additions & 2 deletions dlt/extract/incremental/transform.py
@@ -24,12 +24,15 @@

try:
from dlt.common.libs import pyarrow
-from dlt.common.libs.numpy import numpy
from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem
from dlt.common.libs.pyarrow import from_arrow_scalar, to_arrow_scalar
except MissingDependencyException:
pa = None
pyarrow = None

+try:
+from dlt.common.libs.numpy import numpy
+except MissingDependencyException:
+numpy = None

# NOTE: always import pandas independently from pyarrow
@@ -320,7 +323,9 @@ def _add_unique_index(self, tbl: "pa.Table") -> "pa.Table":
"""Creates unique index if necessary."""
# create unique index if necessary
if self._dlt_index not in tbl.schema.names:
-tbl = pyarrow.append_column(tbl, self._dlt_index, pa.array(numpy.arange(tbl.num_rows)))
+# indices = pa.compute.sequence(start=0, step=1, length=tbl.num_rows,
+indices = pa.array(range(tbl.num_rows))
+tbl = pyarrow.append_column(tbl, self._dlt_index, indices)
return tbl

def __call__(
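numpy now gets its own try/except, so a missing numpy no longer knocks out the whole pyarrow import block, and the unique index is built with pa.array(range(...)) so the arrow path works without numpy at all. A minimal sketch of the numpy-free index column, using a hypothetical column name and a small example table:

    import pyarrow as pa

    tbl = pa.table({"value": ["a", "b", "c"]})
    # a monotonically increasing index without numpy.arange
    indices = pa.array(range(tbl.num_rows))
    tbl = tbl.append_column("_dlt_index", indices)
    assert tbl.column("_dlt_index").to_pylist() == [0, 1, 2]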
13 changes: 9 additions & 4 deletions mypy.ini
@@ -24,9 +24,6 @@ disallow_untyped_defs=false
[mypy-jsonpath_ng.*]
ignore_missing_imports=true

-[mypy-astunparse.*]
-ignore_missing_imports=true
-
[mypy-google.oauth2.*]
ignore_missing_imports=true

@@ -89,6 +86,9 @@ ignore_missing_imports=true
[mypy-pandas.*]
ignore_missing_imports=true

+[mypy-numpy.*]
+ignore_missing_imports=true
+
[mypy-apiclient.*]
ignore_missing_imports=true

@@ -101,8 +101,10 @@ ignore_missing_imports=true

[mypy-connectorx]
ignore_missing_imports=true

[mypy-s3fs.*]
ignore_missing_imports=true

[mypy-win_precise_time]
ignore_missing_imports=true

@@ -121,6 +123,9 @@ ignore_missing_imports = True
[mypy-pytz.*]
ignore_missing_imports = True

+[mypy-sentry_sdk.*]
+ignore_missing_imports = True
+
[mypy-tornado.*]
ignore_missing_imports = True

@@ -130,7 +135,7 @@ ignore_missing_imports = True
[mypy-snowflake.*]
ignore_missing_imports = True

-[mypy-backports.*]
+[mypy-pendulum.*]
ignore_missing_imports = True

[mypy-time_machine.*]
4 changes: 2 additions & 2 deletions tests/extract/test_decorators.py
@@ -1447,11 +1447,11 @@ def some_tx(item):
def some_tx_func(item):
return list(range(item))

-transformer = dlt.transformer(some_tx_func, parallelized=True, data_from=resource)
+transformer = dlt.transformer(some_tx_func, data_from=resource)
pipe_gen = transformer._pipe.gen
inner = pipe_gen(3) # type: ignore
# this is a regular function returning list
-assert inner() == [0, 1, 2] # type: ignore[operator]
+assert inner == [0, 1, 2]
assert list(transformer) == [0, 0, 1, 0, 1, 2]

# Invalid parallel resources
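With parallelized=True removed, the test expects pipe_gen(3) to hand back the list itself rather than a callable that still has to be invoked, which is what the dropped () call and # type: ignore[operator] were accommodating. A generic illustration of the two shapes (not the actual dlt pipe machinery):

    def some_tx_func(item):
        return list(range(item))

    def parallelized(step):
        # wrap the step so the caller gets something to invoke later
        def deferred(item):
            return lambda: step(item)
        return deferred

    assert parallelized(some_tx_func)(3)() == [0, 1, 2]  # old shape: call the result
    assert some_tx_func(3) == [0, 1, 2]                  # new shape: the list itself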
10 changes: 7 additions & 3 deletions tests/libs/test_parquet_writer.py
@@ -4,6 +4,7 @@
import pytest
import datetime # noqa: 251
import time
+import math

from dlt.common import pendulum, Decimal, json
from dlt.common.configuration import inject_section
@@ -12,7 +13,6 @@
from dlt.common.schema.utils import new_column
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.time import ensure_pendulum_datetime
-from dlt.common.libs.pyarrow import from_arrow_scalar

from tests.common.data_writers.utils import get_writer
from tests.cases import (
@@ -165,10 +165,14 @@ def test_parquet_writer_size_file_rotation() -> None:
for i in range(0, 100):
writer.write_data_item([{"col1": i}], columns)

-assert len(writer.closed_files) == 25
+# different arrow version create different file sizes
+no_files = len(writer.closed_files)
+i_per_file = int(math.ceil(100 / no_files))
+assert no_files >= 17 and no_files <= 25

with open(writer.closed_files[4].file_path, "rb") as f:
table = pq.read_table(f)
assert table.column("col1").to_pylist() == list(range(16, 20))
assert table.column("col1").to_pylist() == list(range(4 * i_per_file, 5 * i_per_file))


def test_parquet_writer_config() -> None:
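Because different pyarrow versions rotate the 100 single-row items into a different number of files, the test now derives the expected contents of the fifth file from the observed file count instead of hard-coding 25 files of 4 rows each. The arithmetic, spelled out under an assumed rotation count:

    import math

    no_files = 20  # hypothetical; the test only requires 17 <= no_files <= 25
    i_per_file = int(math.ceil(100 / no_files))
    # file index 4 then holds rows 4*i_per_file .. 5*i_per_file - 1
    assert i_per_file == 5
    assert list(range(4 * i_per_file, 5 * i_per_file)) == [20, 21, 22, 23, 24]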
4 changes: 2 additions & 2 deletions tests/libs/test_pydantic.py
@@ -347,7 +347,7 @@ def test_nested_model_config_propagation() -> None:
assert model_freeze.__fields__["address"].annotation.__name__ == "UserAddressExtraAllow" # type: ignore[index]
# annotated is preserved
type_origin = get_origin(model_freeze.__fields__["address"].rebuild_annotation()) # type: ignore[index]
-assert issubclass(type_origin, Annotated) # type: ignore[arg-type]
+assert type_origin is Annotated
# UserAddress is converted to UserAddressAllow only once
type_annotation = model_freeze.__fields__["address"].annotation # type: ignore[index]
assert type_annotation is get_args(model_freeze.__fields__["unity"].annotation)[0] # type: ignore[index]
@@ -404,7 +404,7 @@ class UserPipe(BaseModel):
assert model_freeze.__fields__["address"].annotation.__name__ == "UserAddressPipeExtraAllow" # type: ignore[index]
# annotated is preserved
type_origin = get_origin(model_freeze.__fields__["address"].rebuild_annotation()) # type: ignore[index]
-assert issubclass(type_origin, Annotated) # type: ignore[arg-type]
+assert type_origin is Annotated
# UserAddress is converted to UserAddressAllow only once
type_annotation = model_freeze.__fields__["address"].annotation # type: ignore[index]
assert type_annotation is get_args(model_freeze.__fields__["unity"].annotation)[0] # type: ignore[index]
24 changes: 18 additions & 6 deletions tests/pipeline/test_pipeline_extra.py
@@ -450,8 +450,8 @@ class Parent(BaseModel):


@pytest.mark.skipif(
importlib.util.find_spec("pandas") is not None,
reason="Test skipped because pandas IS installed",
importlib.util.find_spec("pandas") is not None or importlib.util.find_spec("numpy") is not None,
reason="Test skipped because pandas or numpy ARE installed",
)
def test_arrow_no_pandas() -> None:
import pyarrow as pa
@@ -461,20 +461,32 @@ def test_arrow_no_pandas() -> None:
"Strings": ["apple", "banana", "cherry", "date", "elderberry"],
}

-df = pa.table(data)
+table = pa.table(data)

@dlt.resource
def pandas_incremental(numbers=dlt.sources.incremental("Numbers")):
-yield df
+yield table

+info = dlt.run(
+pandas_incremental(), write_disposition="merge", table_name="data", destination="duckdb"
+)
+
+# change table
+data = {
+"Numbers": [5, 6],
+"Strings": ["elderberry", "burak"],
+}
+
+table = pa.table(data)
+
info = dlt.run(
pandas_incremental(), write_disposition="append", table_name="data", destination="duckdb"
pandas_incremental(), write_disposition="merge", table_name="data", destination="duckdb"
)

with info.pipeline.sql_client() as client: # type: ignore
with client.execute_query("SELECT * FROM data") as c:
with pytest.raises(ImportError):
-df = c.df()
+c.df()


def test_empty_parquet(test_storage: FileStorage) -> None:
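The skip guard now also checks numpy; a plausible reason is that the arrow-only code path, including the closing pytest.raises(ImportError) around c.df(), is only exercised when neither library is installed. A minimal sketch of that gating (importlib.util.find_spec returns None for packages that are not installed):

    import importlib.util

    def missing(package: str) -> bool:
        # find_spec returns None when the package cannot be located
        return importlib.util.find_spec(package) is None

    # the test body is only meaningful in an arrow-only environment
    run_test = missing("pandas") and missing("numpy")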
