Move the range first/limit tests into their own Spark Connect test module.

What this patch does:
  * Adds tests/connect/test_limit_simple.py, containing a `spark_session`
    pytest fixture (starts a Daft Connect server, opens a Spark Connect
    session against it, shuts both down after the test) plus the tests
    `test_range_first` and `test_range_limit`.
  * Removes those same two tests from tests/connect/test_range_simple.py.

NOTE(review): this patch text had been collapsed onto two physical lines;
the line structure is restored here and the indentation inside the hunks is
reconstructed as standard 4-space Python indentation. Confirm against the
original commit before applying (the index hashes below are unchanged).
NOTE(review): the new module keeps commented-out code (`# import time`,
`# time.sleep(2)`) and repeats the "sc://localhost:50051" literal twice;
consider deleting the dead comments and naming the URL once in a follow-up.
NOTE(review): fixture teardown runs server.shutdown() and then
session.stop() with no try/finally, so an exception from shutdown() skips
session.stop(). Consider making the cleanup exception-safe.

diff --git a/tests/connect/test_limit_simple.py b/tests/connect/test_limit_simple.py
new file mode 100644
index 0000000000..e5ab3b6257
--- /dev/null
+++ b/tests/connect/test_limit_simple.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+# import time
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture
+def spark_session():
+    """Fixture to create and clean up a Spark session."""
+    from daft.daft import connect_start
+
+    # Start Daft Connect server
+    server = connect_start("sc://localhost:50051")
+
+    # Initialize Spark Connect session
+    session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()
+
+    yield session
+
+    # Cleanup
+    server.shutdown()
+    session.stop()
+    # time.sleep(2)  # Allow time for session cleanup
+
+
+def test_range_first(spark_session):
+    spark_range = spark_session.range(10)
+    first_row = spark_range.first()
+    assert first_row["id"] == 0, "First row should have id=0"
+
+
+def test_range_limit(spark_session):
+    spark_range = spark_session.range(10)
+    limited_df = spark_range.limit(5).toPandas()
+    assert len(limited_df) == 5, "Limited DataFrame should have 5 rows"
+    assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"
diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py
index 34e82ebbcf..b277d38481 100644
--- a/tests/connect/test_range_simple.py
+++ b/tests/connect/test_range_simple.py
@@ -12,16 +12,3 @@ def test_range_operation(spark_session):
     # Verify the DataFrame has expected values
     assert len(pandas_df) == 10, "DataFrame should have 10 rows"
     assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9"
-
-
-def test_range_first(spark_session):
-    spark_range = spark_session.range(10)
-    first_row = spark_range.first()
-    assert first_row["id"] == 0, "First row should have id=0"
-
-
-def test_range_limit(spark_session):
-    spark_range = spark_session.range(10)
-    limited_df = spark_range.limit(5).toPandas()
-    assert len(limited_df) == 5, "Limited DataFrame should have 5 rows"
-    assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"