Skip to content

Commit

Permalink
separate tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 20, 2024
1 parent 70a9635 commit ebc956e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
37 changes: 37 additions & 0 deletions tests/connect/test_limit_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

# import time
import pytest
from pyspark.sql import SparkSession


@pytest.fixture
def spark_session():
"""Fixture to create and clean up a Spark session."""
from daft.daft import connect_start

# Start Daft Connect server
server = connect_start("sc://localhost:50051")

# Initialize Spark Connect session
session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()

yield session

# Cleanup
server.shutdown()
session.stop()
# time.sleep(2) # Allow time for session cleanup


def test_range_first(spark_session):
spark_range = spark_session.range(10)
first_row = spark_range.first()
assert first_row["id"] == 0, "First row should have id=0"


def test_range_limit(spark_session):
spark_range = spark_session.range(10)
limited_df = spark_range.limit(5).toPandas()
assert len(limited_df) == 5, "Limited DataFrame should have 5 rows"
assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"
13 changes: 0 additions & 13 deletions tests/connect/test_range_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,3 @@ def test_range_operation(spark_session):
# Verify the DataFrame has expected values
assert len(pandas_df) == 10, "DataFrame should have 10 rows"
assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9"


def test_range_first(spark_session):
spark_range = spark_session.range(10)
first_row = spark_range.first()
assert first_row["id"] == 0, "First row should have id=0"


def test_range_limit(spark_session):
spark_range = spark_session.range(10)
limited_df = spark_range.limit(5).toPandas()
assert len(limited_df) == 5, "Limited DataFrame should have 5 rows"
assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"

0 comments on commit ebc956e

Please sign in to comment.