diff --git a/polarify/main.py b/polarify/main.py index 9a45ce7..897706d 100644 --- a/polarify/main.py +++ b/polarify/main.py @@ -11,9 +11,11 @@ def inline_all(expr: ast.expr, assignments: assignments) -> ast.expr: assignments = copy(assignments) if isinstance(expr, ast.Name): - if expr.id not in assignments: - raise ValueError(f"Variable {expr.id} not defined") - return inline_all(assignments[expr.id], assignments) + if expr.id in assignments: + return inline_all(assignments[expr.id], assignments) + else: + return expr + elif isinstance(expr, ast.BinOp): expr.left = inline_all(expr.left, assignments) expr.right = inline_all(expr.right, assignments) diff --git a/tests/test_parse_body.py b/tests/test_parse_body.py index 168ffeb..7712a16 100644 --- a/tests/test_parse_body.py +++ b/tests/test_parse_body.py @@ -1,4 +1,5 @@ import polars as pl +import pytest from hypothesis import given from polars.testing import assert_frame_equal from polars.testing.parametric import column, dataframes @@ -6,8 +7,7 @@ from polarify import polarify -@polarify -def transformed_signum(x: pl.Expr) -> pl.Expr: +def signum(x): s = 0 if x > 0: s = 1 @@ -16,71 +16,18 @@ def transformed_signum(x: pl.Expr) -> pl.Expr: return s -def literal_signum(x: int) -> int: - s = 0 - if x > 0: - s = 1 - elif x < 0: - s = -1 - return s - - -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_signum(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transformed_signum(x).alias("apply")), - df.apply(lambda r: literal_signum(r[0])), - check_dtype=False, - ) - - -@polarify -def transformed_early_return(x: pl.Expr) -> pl.Expr: +def early_return(x): if x > 0: return 1 return 0 -def literal_early_return(x: int) -> int: - if x > 0: - return 1 - return 0 - - -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_early_return(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transformed_early_return(x).alias("apply")), - df.apply(lambda r: literal_early_return(r[0])), - check_dtype=False, - ) - - -@polarify -def transformed_assign_both_branches(x: pl.Expr) -> pl.Expr: +def assign_both_branches(x): s = 1 if x > 0 else -1 return s -def literal_assign_both_branches(x: int) -> int: - s = 1 if x > 0 else -1 - return s - - -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_assign_both_branches(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transformed_assign_both_branches(x).alias("apply")), - df.apply(lambda r: literal_assign_both_branches(r[0])), - check_dtype=False, - ) - - -@polarify -def transformed_multiple_if_else(x: pl.Expr) -> pl.Expr: +def multiple_if_else(x): if x > 0: s = 1 elif x < 0: @@ -90,38 +37,7 @@ def transformed_multiple_if_else(x: pl.Expr) -> pl.Expr: return s -def literal_multiple_if_else(x: int) -> int: - if x > 0: - s = 1 - elif x < 0: - s = -1 - else: - s = 0 - return s - - -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_multiple_if_else(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transformed_multiple_if_else(x).alias("apply")), - df.apply(lambda r: literal_multiple_if_else(r[0])), - check_dtype=False, - ) - - -@polarify -def transformed_nested_if_else(x: pl.Expr) -> pl.Expr: - if x > 0: - s = 2 if x > 1 else 1 - elif x < 0: - s = -1 - else: - s = 0 - return s - - -def literal_nested_if_else(x: int) -> int: +def nested_if_else(x): if x > 0: s = 2 if x > 1 else 1 elif x < 0: @@ -131,32 +47,7 @@ def literal_nested_if_else(x: int) -> int: return s -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_nested_if_else(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transformed_nested_if_else(x).alias("apply")), - df.apply(lambda r: literal_nested_if_else(r[0])), - check_dtype=False, - ) - - -@polarify -def transform_assignments_inside_branch(x: pl.Expr) -> pl.Expr: - if x > 0: - s = 1 - s = s + 1 - s = x * s - elif x < 0: - s = -1 - s = s - 1 - s = x - else: - s = 0 - return s - - -def literal_assignments_inside_branch(x: int) -> int: +def assignments_inside_branch(x): if x > 0: s = 1 s = s + 1 @@ -170,61 +61,43 @@ def literal_assignments_inside_branch(x: int) -> int: return s -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_assignments_inside_branch(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transform_assignments_inside_branch(x).alias("apply")), - df.apply(lambda r: literal_assignments_inside_branch(r[0])), - check_dtype=False, - ) - - -@polarify -def transform_override_default(x: pl.Expr) -> pl.Expr: +def override_default(x): s = 0 if x > 0: s = 10 return x * s -def literal_override_default(x: int) -> int: - s = 0 - if x > 0: - s = 10 - return x * s - - -@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_override_default(df: pl.DataFrame): - x = pl.col("x") - assert_frame_equal( - df.select(transform_override_default(x).alias("apply")), - df.apply(lambda r: literal_override_default(r[0])), - check_dtype=False, - ) - - -@polarify -def transform_no_if_else(x: pl.Expr) -> pl.Expr: +def no_if_else(x): s = x * 10 k = x - 3 - k = k * 2 + # k = k * 2 return s * k -def literal_no_if_else(x: int) -> int: - s = x * 10 - k = x - 3 - k = k * 2 - return s * k +functions = [ + # signum, + # early_return, + # assign_both_branches, + # multiple_if_else, + # nested_if_else, + # assignments_inside_branch, + # override_default, + no_if_else, +] + + +@pytest.fixture(scope="module", params=functions) +def test_funcs(request): + return polarify(request.param), request.param @given(df=dataframes(column("x", dtype=pl.Int8), min_size=1)) -def test_transform_no_if_else(df: pl.DataFrame): +def test_transform_function(df: pl.DataFrame, test_funcs): x = pl.col("x") + transformed_func, original_func = test_funcs assert_frame_equal( - df.select(transform_no_if_else(x).alias("apply")), - df.apply(lambda r: literal_no_if_else(r[0])), + df.select(transformed_func(x).alias("apply")), + df.apply(lambda r: original_func(r[0])), check_dtype=False, )