Skip to content

Commit

Permalink
use parametrized fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbe7a committed Jul 29, 2023
1 parent 7928c5c commit 7c7fc7a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 159 deletions.
8 changes: 5 additions & 3 deletions polarify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
def inline_all(expr: ast.expr, assignments: assignments) -> ast.expr:
assignments = copy(assignments)
if isinstance(expr, ast.Name):
if expr.id not in assignments:
raise ValueError(f"Variable {expr.id} not defined")
return inline_all(assignments[expr.id], assignments)
if expr.id in assignments:
return inline_all(assignments[expr.id], assignments)
else:
return expr

elif isinstance(expr, ast.BinOp):
expr.left = inline_all(expr.left, assignments)
expr.right = inline_all(expr.right, assignments)
Expand Down
185 changes: 29 additions & 156 deletions tests/test_parse_body.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import polars as pl
import pytest
from hypothesis import given
from polars.testing import assert_frame_equal
from polars.testing.parametric import column, dataframes

from polarify import polarify


@polarify
def transformed_signum(x: pl.Expr) -> pl.Expr:
def signum(x):
s = 0
if x > 0:
s = 1
Expand All @@ -16,71 +16,18 @@ def transformed_signum(x: pl.Expr) -> pl.Expr:
return s


def literal_signum(x: int) -> int:
s = 0
if x > 0:
s = 1
elif x < 0:
s = -1
return s


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_signum(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transformed_signum(x).alias("apply")),
df.apply(lambda r: literal_signum(r[0])),
check_dtype=False,
)


@polarify
def transformed_early_return(x: pl.Expr) -> pl.Expr:
def early_return(x):
if x > 0:
return 1
return 0


def literal_early_return(x: int) -> int:
if x > 0:
return 1
return 0


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_early_return(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transformed_early_return(x).alias("apply")),
df.apply(lambda r: literal_early_return(r[0])),
check_dtype=False,
)


@polarify
def transformed_assign_both_branches(x: pl.Expr) -> pl.Expr:
def assign_both_branches(x):
s = 1 if x > 0 else -1
return s


def literal_assign_both_branches(x: int) -> int:
s = 1 if x > 0 else -1
return s


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_assign_both_branches(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transformed_assign_both_branches(x).alias("apply")),
df.apply(lambda r: literal_assign_both_branches(r[0])),
check_dtype=False,
)


@polarify
def transformed_multiple_if_else(x: pl.Expr) -> pl.Expr:
def multiple_if_else(x):
if x > 0:
s = 1
elif x < 0:
Expand All @@ -90,38 +37,7 @@ def transformed_multiple_if_else(x: pl.Expr) -> pl.Expr:
return s


def literal_multiple_if_else(x: int) -> int:
if x > 0:
s = 1
elif x < 0:
s = -1
else:
s = 0
return s


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_multiple_if_else(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transformed_multiple_if_else(x).alias("apply")),
df.apply(lambda r: literal_multiple_if_else(r[0])),
check_dtype=False,
)


@polarify
def transformed_nested_if_else(x: pl.Expr) -> pl.Expr:
if x > 0:
s = 2 if x > 1 else 1
elif x < 0:
s = -1
else:
s = 0
return s


def literal_nested_if_else(x: int) -> int:
def nested_if_else(x):
if x > 0:
s = 2 if x > 1 else 1
elif x < 0:
Expand All @@ -131,32 +47,7 @@ def literal_nested_if_else(x: int) -> int:
return s


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_nested_if_else(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transformed_nested_if_else(x).alias("apply")),
df.apply(lambda r: literal_nested_if_else(r[0])),
check_dtype=False,
)


@polarify
def transform_assignments_inside_branch(x: pl.Expr) -> pl.Expr:
if x > 0:
s = 1
s = s + 1
s = x * s
elif x < 0:
s = -1
s = s - 1
s = x
else:
s = 0
return s


def literal_assignments_inside_branch(x: int) -> int:
def assignments_inside_branch(x):
if x > 0:
s = 1
s = s + 1
Expand All @@ -170,61 +61,43 @@ def literal_assignments_inside_branch(x: int) -> int:
return s


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_assignments_inside_branch(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transform_assignments_inside_branch(x).alias("apply")),
df.apply(lambda r: literal_assignments_inside_branch(r[0])),
check_dtype=False,
)


@polarify
def transform_override_default(x: pl.Expr) -> pl.Expr:
def override_default(x):
s = 0
if x > 0:
s = 10
return x * s


def literal_override_default(x: int) -> int:
s = 0
if x > 0:
s = 10
return x * s


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_override_default(df: pl.DataFrame):
x = pl.col("x")
assert_frame_equal(
df.select(transform_override_default(x).alias("apply")),
df.apply(lambda r: literal_override_default(r[0])),
check_dtype=False,
)


@polarify
def transform_no_if_else(x: pl.Expr) -> pl.Expr:
def no_if_else(x):
s = x * 10
k = x - 3
k = k * 2
# k = k * 2
return s * k


def literal_no_if_else(x: int) -> int:
s = x * 10
k = x - 3
k = k * 2
return s * k
functions = [
# signum,
# early_return,
# assign_both_branches,
# multiple_if_else,
# nested_if_else,
# assignments_inside_branch,
# override_default,
no_if_else,
]


@pytest.fixture(scope="module", params=functions)
def test_funcs(request):
return polarify(request.param), request.param


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
def test_transform_no_if_else(df: pl.DataFrame):
def test_transform_function(df: pl.DataFrame, test_funcs):
x = pl.col("x")
transformed_func, original_func = test_funcs
assert_frame_equal(
df.select(transform_no_if_else(x).alias("apply")),
df.apply(lambda r: literal_no_if_else(r[0])),
df.select(transformed_func(x).alias("apply")),
df.apply(lambda r: original_func(r[0])),
check_dtype=False,
)

0 comments on commit 7c7fc7a

Please sign in to comment.