Skip to content

Commit

Permalink
feat(pyspark): support ibis.pyspark.connect() (#8515)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Mar 1, 2024
1 parent ef39aab commit 0f663e6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
7 changes: 6 additions & 1 deletion ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cached_dataframes = {}

def do_connect(self, session: SparkSession) -> None:
def do_connect(self, session: SparkSession | None = None) -> None:
"""Create a PySpark `Backend` for use with Ibis.
Parameters
Expand All @@ -147,6 +147,11 @@ def do_connect(self, session: SparkSession) -> None:
<ibis.backends.pyspark.Backend at 0x...>
"""
if session is None:
from pyspark.sql import SparkSession

session = SparkSession.builder.getOrCreate()

self._context = session.sparkContext
self._session = session

Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/pyspark/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,9 @@ def test_string_literal_backslash_escaping(con):
expr = ibis.literal("\\d\\e")
result = con.execute(expr)
assert result == "\\d\\e"


def test_connect_without_explicit_session():
    """Check that `ibis.pyspark.connect()` works when no session is passed.

    Connects with no arguments, runs a one-row SQL query through the
    resulting backend, and compares the pandas result to the expected frame.
    """
    backend = ibis.pyspark.connect()
    actual = backend.sql("SELECT CAST(1 AS BIGINT) as foo").to_pandas()
    expected = pd.DataFrame({"foo": [1]})
    tm.assert_frame_equal(actual, expected)

0 comments on commit 0f663e6

Please sign in to comment.