diff --git a/python/datafusion/tests/test_catalog.py b/python/datafusion/tests/test_catalog.py index 84aa801bc48fd..c68235a4f11fa 100644 --- a/python/datafusion/tests/test_catalog.py +++ b/python/datafusion/tests/test_catalog.py @@ -59,15 +59,17 @@ def test_basic(ctx, database): ctx.catalog("non-existent") default = ctx.catalog() - assert default.names() == ['public'] + assert default.names() == ["public"] - database = default.database('public') - assert database.names() == {'csv1', 'csv', 'csv2'} + database = default.database("public") + assert database.names() == {"csv1", "csv", "csv2"} - table = database.table('csv') + table = database.table("csv") assert table.kind == "physical" - assert table.schema == pa.schema([ - pa.field("int", pa.int64(), nullable=False), - pa.field("str", pa.string(), nullable=False), - pa.field("float", pa.float64(), nullable=False) - ]) + assert table.schema == pa.schema( + [ + pa.field("int", pa.int64(), nullable=False), + pa.field("str", pa.string(), nullable=False), + pa.field("float", pa.float64(), nullable=False), + ] + ) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index b1b5c843ba112..543e8356cd9ae 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -18,7 +18,7 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext +from datafusion import ExecutionContext, DataFrame from datafusion import functions as f from . import generic as helpers @@ -64,7 +64,7 @@ def test_filter(df): def test_sort(df): - df = df.sort([f.col("b").sort(ascending=False)]) + df = df.sort(f.col("b").sort(ascending=False)) table = pa.Table.from_batches(df.collect()) expected = {"a": [3, 2, 1], "b": [6, 5, 4]} @@ -108,7 +108,7 @@ def test_join(): df1 = ctx.create_dataframe([[batch]]) df = df.join(df1, on="a", how="inner") - df = df.sort([f.col("a").sort(ascending=True)]) + df = df.sort(f.col("a").sort(ascending=True)) table = pa.Table.from_batches(df.collect()) expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} @@ -132,3 +132,4 @@ def test_get_dataframe(tmp_path): ctx.register_csv("csv", path) df = ctx.table("csv") + assert isinstance(df, DataFrame)