Skip to content

Commit

Permalink
fix if / else var assignment lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbe7a committed Jul 29, 2023
1 parent 7c7fc7a commit f8393d0
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
1 change: 1 addition & 0 deletions polarify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def polarify(func):

# Unparse the modified AST back into source code
new_func_code = ast.unparse(tree)
print(new_func_code)

# Execute the new function code in the original function's globals
exec_globals = func.__globals__
Expand Down
30 changes: 21 additions & 9 deletions polarify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
# TODO: make Walruss throw ValueError
# TODO: Switch

assignments = dict[str, ast.expr]
Assignments = dict[str, ast.expr]


def inline_all(expr: ast.expr, assignments: assignments) -> ast.expr:
def inline_all(expr: ast.expr, assignments: Assignments) -> ast.expr:
assignments = copy(assignments)
if isinstance(expr, ast.Name):
if expr.id in assignments:
Expand Down Expand Up @@ -39,7 +39,7 @@ def is_returning_body(stmts: list[ast.stmt]) -> bool:
return False


def handle_assign(stmt: ast.Assign, assignments: assignments) -> assignments:
def handle_assign(stmt: ast.Assign, assignments: Assignments) -> Assignments:
assignments = copy(assignments)
diff_assignments = {}

Expand Down Expand Up @@ -67,28 +67,40 @@ def handle_assign(stmt: ast.Assign, assignments: assignments) -> assignments:
return diff_assignments


def handle_non_returning_if(stmt: ast.If, assignments: assignments) -> assignments:
def handle_non_returning_if(stmt: ast.If, assignments: Assignments) -> Assignments:
assignments = copy(assignments)
assert not is_returning_body(stmt.orelse) and not is_returning_body(stmt.body)
test = inline_all(stmt.test, assignments)

diff_assignments = {}
all_vars_changed_in_body = get_all_vars_changed_in_body(stmt.body, assignments)
all_vars_changed_in_orelse = get_all_vars_changed_in_body(stmt.orelse, assignments)

def updated_or_default_assignments(var: str, diff: Assignments) -> ast.expr:
if var in diff:
return diff[var]
elif var in assignments:
return assignments[var]
else:
raise ValueError(
f"Variable {var} has to be either defined in"
" all branches or have a previous defintion"
)

for var in all_vars_changed_in_body | all_vars_changed_in_orelse:
expr = build_polars_when_then_otherwise(
test,
all_vars_changed_in_body.get(var, assignments[var]),
all_vars_changed_in_orelse.get(var, assignments[var]),
updated_or_default_assignments(var, all_vars_changed_in_body),
updated_or_default_assignments(var, all_vars_changed_in_orelse),
)
assignments[var] = expr
diff_assignments[var] = expr
return diff_assignments


def get_all_vars_changed_in_body(
body: list[ast.stmt], assignments: assignments
) -> assignments:
body: list[ast.stmt], assignments: Assignments
) -> Assignments:
assignments = copy(assignments)
diff_assignments = {}

Expand Down Expand Up @@ -132,7 +144,7 @@ def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast


def parse_body(
full_body: list[ast.stmt], assignments: Union[assignments, None] = None
full_body: list[ast.stmt], assignments: Union[Assignments, None] = None
) -> ast.expr:
if assignments is None:
assignments = {}
Expand Down
46 changes: 37 additions & 9 deletions tests/test_parse_body.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# ruff: noqa

import polars as pl
import pytest
from hypothesis import given
from polars.testing import assert_frame_equal
from polars.testing.parametric import column, dataframes
from hypothesis.strategies import integers

from polarify import polarify

Expand All @@ -23,6 +26,14 @@ def early_return(x):


def assign_both_branches(x):
if x > 0:
s = 1
else:
s = -1
return s


def if_expr(x):
s = 1 if x > 0 else -1
return s

Expand All @@ -38,6 +49,19 @@ def multiple_if_else(x):


def nested_if_else(x):
if x > 0:
if x > 1:
s = 2
else:
s = 1
elif x < 0:
s = -1
else:
s = 0
return s


def nested_if_else_expr(x):
if x > 0:
s = 2 if x > 1 else 1
elif x < 0:
Expand Down Expand Up @@ -71,18 +95,20 @@ def override_default(x):
def no_if_else(x):
s = x * 10
k = x - 3
# k = k * 2
k = k * 2
return s * k


functions = [
# signum,
# early_return,
# assign_both_branches,
# multiple_if_else,
# nested_if_else,
# assignments_inside_branch,
# override_default,
signum,
early_return,
assign_both_branches,
# if_expr,
multiple_if_else,
nested_if_else,
# nested_if_else_expr,
assignments_inside_branch,
override_default,
no_if_else,
]

Expand All @@ -92,7 +118,9 @@ def test_funcs(request):
return polarify(request.param), request.param


@given(df=dataframes(column("x", dtype=pl.Int8), min_size=1))
@given(
df=dataframes(column("x", dtype=pl.Int64, strategy=integers(-100, 100)), min_size=1)
)
def test_transform_function(df: pl.DataFrame, test_funcs):
x = pl.col("x")
transformed_func, original_func = test_funcs
Expand Down

0 comments on commit f8393d0

Please sign in to comment.