diff --git a/hy/compiler.py b/hy/compiler.py index f3cc9ac95..e9037beac 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -375,6 +375,7 @@ def __init__(self, module_name): self.anon_var_count = 0 self.imports = defaultdict(set) self.module_name = module_name + self.temp_if = None if not module_name.startswith("hy.core"): # everything in core needs to be explicit. load_stdlib() @@ -1004,8 +1005,16 @@ def compile_if(self, expression): body = self.compile(expression.pop(0)) orel = Result() + nested = root = False if expression: - orel = self.compile(expression.pop(0)) + orel_expr = expression.pop(0) + if isinstance(orel_expr, HyExpression) and isinstance(orel_expr[0], + HySymbol) and orel_expr[0] == 'if': + # Nested ifs: don't waste temporaries + root = self.temp_if is None + nested = True + self.temp_if = self.temp_if or self.get_anon_var() + orel = self.compile(orel_expr) # We want to hoist the statements from the condition ret = cond @@ -1013,7 +1022,7 @@ def compile_if(self, expression): if body.stmts or orel.stmts: # We have statements in our bodies # Get a temporary variable for the result storage - var = self.get_anon_var() + var = self.temp_if or self.get_anon_var() name = ast.Name(id=ast_str(var), arg=ast_str(var), ctx=ast.Store(), lineno=expression.start_line, @@ -1026,10 +1035,12 @@ def compile_if(self, expression): col_offset=expression.start_column) # and of the else clause - orel += ast.Assign(targets=[name], - value=orel.force_expr, - lineno=expression.start_line, - col_offset=expression.start_column) + if not nested or not orel.stmts or (not root and + var != self.temp_if): + orel += ast.Assign(targets=[name], + value=orel.force_expr, + lineno=expression.start_line, + col_offset=expression.start_column) # Then build the if ret += ast.If(test=ret.force_expr, @@ -1052,6 +1063,10 @@ def compile_if(self, expression): orelse=orel.force_expr, lineno=expression.start_line, col_offset=expression.start_column) + + if root: + self.temp_if = None + return ret @builds("break")