diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index c0abcefcb0..968ce42bdf 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -84,6 +84,8 @@ COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"]) +COMPARATOR_INSTRUCTIONS = ("gt", "lt", "sgt", "slt") + if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -230,6 +232,14 @@ def is_volatile(self) -> bool: def is_commutative(self) -> bool: return self.opcode in COMMUTATIVE_INSTRUCTIONS + @property + def is_comparator(self) -> bool: + return self.opcode in COMPARATOR_INSTRUCTIONS + + @property + def flippable(self) -> bool: + return self.is_commutative or self.is_comparator + @property def is_bb_terminator(self) -> bool: return self.opcode in BB_TERMINATORS @@ -282,6 +292,21 @@ def get_outputs(self) -> list[IROperand]: """ return [self.output] if self.output else [] + def flip(self): + """ + Flip operands for commutative or comparator opcodes + """ + assert self.flippable + self.operands.reverse() + + if self.is_commutative: + return + + if self.opcode in ("gt", "sgt"): + self.opcode = self.opcode.replace("g", "l") + else: + self.opcode = self.opcode.replace("l", "g") + def replace_operands(self, replacements: dict) -> None: """ Update operands with replacements. diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 2bf82810b6..a8d68ad676 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -2,7 +2,7 @@ import vyper.venom.effects as effects from vyper.utils import OrderedSet -from vyper.venom.analysis import DFGAnalysis, IRAnalysesCache, LivenessAnalysis +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis from vyper.venom.basicblock import IRBasicBlock, IRInstruction from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -10,50 +10,41 @@ class DFTPass(IRPass): function: IRFunction - inst_offspring: dict[IRInstruction, OrderedSet[IRInstruction]] + data_offspring: dict[IRInstruction, OrderedSet[IRInstruction]] visited_instructions: OrderedSet[IRInstruction] - ida: dict[IRInstruction, OrderedSet[IRInstruction]] - - def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): - super().__init__(analyses_cache, function) - self.inst_offspring = {} + # "data dependency analysis" + dda: dict[IRInstruction, OrderedSet[IRInstruction]] + # "effect dependency analysis" + eda: dict[IRInstruction, OrderedSet[IRInstruction]] def run_pass(self) -> None: - self.inst_offspring = {} + self.data_offspring = {} self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) - basic_blocks = list(self.function.get_basic_blocks()) - self.function.clear_basic_blocks() - for bb in basic_blocks: + for bb in self.function.get_basic_blocks(): self._process_basic_block(bb) self.analyses_cache.invalidate_analysis(LivenessAnalysis) def _process_basic_block(self, bb: IRBasicBlock) -> None: - self.function.append_basic_block(bb) - self._calculate_dependency_graphs(bb) self.instructions = list(bb.pseudo_instructions) non_phi_instructions = list(bb.non_phi_instructions) self.visited_instructions = OrderedSet() - for inst in non_phi_instructions: - self._calculate_instruction_offspring(inst) + for inst in bb.instructions: + self._calculate_data_offspring(inst) # Compute entry points in the graph of instruction dependencies entry_instructions: OrderedSet[IRInstruction] = OrderedSet(non_phi_instructions) for inst in non_phi_instructions: - to_remove = self.ida.get(inst, OrderedSet()) - if len(to_remove) > 0: - entry_instructions.dropmany(to_remove) + to_remove = self.dda.get(inst, OrderedSet()) | self.eda.get(inst, OrderedSet()) + entry_instructions.dropmany(to_remove) entry_instructions_list = list(entry_instructions) - # Move the terminator instruction to the end of the list - self._move_terminator_to_end(entry_instructions_list) - self.visited_instructions = OrderedSet() for inst in entry_instructions_list: self._process_instruction_r(self.instructions, inst) @@ -61,13 +52,6 @@ def _process_basic_block(self, bb: IRBasicBlock) -> None: bb.instructions = self.instructions assert bb.is_terminated, f"Basic block should be terminated {bb}" - def _move_terminator_to_end(self, instructions: list[IRInstruction]) -> None: - terminator = next((inst for inst in instructions if inst.is_bb_terminator), None) - if terminator is None: - raise ValueError(f"Basic block should have a terminator instruction {self.function}") - instructions.remove(terminator) - instructions.append(terminator) - def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInstruction): if inst in self.visited_instructions: return @@ -76,14 +60,23 @@ def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInst if inst.is_pseudo: return - children = list(self.ida[inst]) + children = list(self.dda[inst] | self.eda[inst]) - def key(x): - cost = inst.operands.index(x.output) if x.output in inst.operands else 0 - return cost - len(self.inst_offspring[x]) * 0.5 + def cost(x: IRInstruction) -> int | float: + if x in self.eda[inst] or inst.flippable: + ret = -1 * int(len(self.data_offspring[x]) > 0) + else: + assert x in self.dda[inst] # sanity check + assert x.output is not None # help mypy + ret = inst.operands.index(x.output) + return ret # heuristic: sort by size of child dependency graph - children.sort(key=key) + orig_children = children.copy() + children.sort(key=cost) + + if inst.flippable and (orig_children != children): + inst.flip() for dep_inst in children: self._process_instruction_r(instructions, dep_inst) @@ -92,7 +85,8 @@ def key(x): def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None: # ida: instruction dependency analysis - self.ida = defaultdict(OrderedSet) + self.dda = defaultdict(OrderedSet) + self.eda = defaultdict(OrderedSet) non_phis = list(bb.non_phi_instructions) @@ -106,33 +100,31 @@ def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None: for op in inst.operands: dep = self.dfg.get_producing_instruction(op) if dep is not None and dep.parent == bb: - self.ida[inst].add(dep) + self.dda[inst].add(dep) write_effects = inst.get_write_effects() read_effects = inst.get_read_effects() for write_effect in write_effects: if write_effect in last_read_effects: - self.ida[inst].add(last_read_effects[write_effect]) + self.eda[inst].add(last_read_effects[write_effect]) last_write_effects[write_effect] = inst for read_effect in read_effects: if read_effect in last_write_effects and last_write_effects[read_effect] != inst: - self.ida[inst].add(last_write_effects[read_effect]) + self.eda[inst].add(last_write_effects[read_effect]) last_read_effects[read_effect] = inst - def _calculate_instruction_offspring(self, inst: IRInstruction): - if inst in self.inst_offspring: - return self.inst_offspring[inst] + def _calculate_data_offspring(self, inst: IRInstruction): + if inst in self.data_offspring: + return self.data_offspring[inst] - self.inst_offspring[inst] = self.ida[inst].copy() + self.data_offspring[inst] = self.dda[inst].copy() - deps = self.ida[inst] + deps = self.dda[inst] for dep_inst in deps: assert inst.parent == dep_inst.parent - if dep_inst.opcode == "store": - continue - res = self._calculate_instruction_offspring(dep_inst) - self.inst_offspring[inst] |= res + res = self._calculate_data_offspring(dep_inst) + self.data_offspring[inst] |= res - return self.inst_offspring[inst] + return self.data_offspring[inst]