Skip to content

Commit

Permalink
Initial commit of pyspark DataFrame backend (ibis-project#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
icexelloss authored Jul 16, 2019
1 parent 3afa8b0 commit 94d8b0e
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ibis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
# pip install ibis-framework[spark]
import ibis.spark.api as spark # noqa: F401

with suppress(ImportError):
import ibis.pyspark.api as pyspark


def hdfs_connect(
host='localhost',
Expand Down
12 changes: 12 additions & 0 deletions ibis/pyspark/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ibis.pyspark.client import PysparkClient


def connect(**kwargs):
"""
Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient,
which pipes them into SparkContext. See documentation for SparkContext:
https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
"""
client = PysparkClient(**kwargs)

return client
19 changes: 19 additions & 0 deletions ibis/pyspark/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from ibis.spark.client import SparkClient
from ibis.pyspark.operations import PysparkTable
from ibis.pyspark.compiler import translate

class PysparkClient(SparkClient):
"""
An ibis client that uses Pyspark SQL Dataframe
"""

dialect = None
table_class = PysparkTable

def compile(self, expr, *args, **kwargs):
"""Compile an ibis expression to a Pyspark DataFrame object
"""
return translate(expr)

def execute(self, df, params=None, limit='default', **kwargs):
return df.toPandas()
69 changes: 69 additions & 0 deletions ibis/pyspark/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import ibis.common as com
import ibis.sql.compiler as comp
import ibis.expr.operations as ops


from ibis.pyspark.operations import PysparkTable
from ibis.sql.compiler import Dialect

_operation_registry = {
}

class PysparkExprTranslator:
_registry = _operation_registry

@classmethod
def compiles(cls, klass):
def decorator(f):
cls._registry[klass] = f
return f

return decorator

def translate(self, expr):
# The operation node type the typed expression wraps
op = expr.op()

if type(op) in self._registry:
formatter = self._registry[type(op)]
return formatter(self, expr)
else:
raise com.OperationNotDefinedError(
'No translation rule for {}'.format(type(op))
)


class PysparkDialect(Dialect):
translator = PysparkExprTranslator


compiles = PysparkExprTranslator.compiles

@compiles(PysparkTable)
def compile_datasource(t, expr):
op = expr.op()
name, _, client = op.args
return client._session.table(name)

@compiles(ops.Selection)
def compile_selection(t, expr):
op = expr.op()
src_table = t.translate(op.selections[0])
for selection in op.selections[1:]:
column_name = selection.get_name()
column = t.translate(selection)
src_table = src_table.withColumn(column_name, column)

return src_table


@compiles(ops.TableColumn)
def compile_column(t, expr):
op = expr.op()
return t.translate(op.table)[op.name]


t = PysparkExprTranslator()

def translate(expr):
return t.translate(expr)
4 changes: 4 additions & 0 deletions ibis/pyspark/operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import ibis.expr.operations as ops

class PysparkTable(ops.DatabaseTable):
pass
50 changes: 50 additions & 0 deletions ibis/pyspark/tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pandas as pd
import pandas.util.testing as tm
import pytest

import ibis

@pytest.fixture(scope='session')
def client():
client = ibis.pyspark.connect()
df = client._session.range(0, 10)
df.createTempView('table1')
return client


def test_basic(client):
table = client.table('table1')
result = table.compile().toPandas()
expected = pd.DataFrame({'id': range(0, 10)})

tm.assert_frame_equal(result, expected)


def test_projection(client):
import ipdb; ipdb.set_trace()
table = client.table('table1')
result1 = table.mutate(v=table['id']).compile().toPandas()

expected1 = pd.DataFrame(
{
'id': range(0, 10),
'v': range(0, 10)
}
)

result2 = table.mutate(v=table['id']).mutate(v2=table['id']).compile().toPandas()

expected2 = pd.DataFrame(
{
'id': range(0, 10),
'v': range(0, 10),
'v2': range(0, 10)
}
)

tm.assert_frame_equal(result1, expected1)
tm.assert_frame_equal(result2, expected2)


def test_udf(client):
table = client.table('table1')

0 comments on commit 94d8b0e

Please sign in to comment.