Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[venom]: make dft-pass commutative aware #4358

Merged
merged 14 commits into from
Nov 20, 2024
25 changes: 25 additions & 0 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@

COMMUTATIVE_INSTRUCTIONS = frozenset(["add", "mul", "smul", "or", "xor", "and", "eq"])

COMPARATOR_INSTRUCTIONS = ("gt", "lt", "sgt", "slt")

if TYPE_CHECKING:
from vyper.venom.function import IRFunction

Expand Down Expand Up @@ -230,6 +232,14 @@ def is_volatile(self) -> bool:
def is_commutative(self) -> bool:
return self.opcode in COMMUTATIVE_INSTRUCTIONS

@property
def is_comparator(self) -> bool:
return self.opcode in COMPARATOR_INSTRUCTIONS

@property
def flippable(self) -> bool:
return self.is_commutative or self.is_comparator

@property
def is_bb_terminator(self) -> bool:
return self.opcode in BB_TERMINATORS
Expand Down Expand Up @@ -282,6 +292,21 @@ def get_outputs(self) -> list[IROperand]:
"""
return [self.output] if self.output else []

def flip(self):
"""
Flip operands for commutative or comparator opcodes
"""
assert self.flippable
self.operands.reverse()

if self.is_commutative:
return

if self.opcode in ("gt", "sgt"):
self.opcode = self.opcode.replace("g", "l")
else:
self.opcode = self.opcode.replace("l", "g")

def replace_operands(self, replacements: dict) -> None:
"""
Update operands with replacements.
Expand Down
86 changes: 39 additions & 47 deletions vyper/venom/passes/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,56 @@

import vyper.venom.effects as effects
from vyper.utils import OrderedSet
from vyper.venom.analysis import DFGAnalysis, IRAnalysesCache, LivenessAnalysis
from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis
from vyper.venom.basicblock import IRBasicBlock, IRInstruction
from vyper.venom.function import IRFunction
from vyper.venom.passes.base_pass import IRPass


class DFTPass(IRPass):
function: IRFunction
inst_offspring: dict[IRInstruction, OrderedSet[IRInstruction]]
data_offspring: dict[IRInstruction, OrderedSet[IRInstruction]]
visited_instructions: OrderedSet[IRInstruction]
ida: dict[IRInstruction, OrderedSet[IRInstruction]]

def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
super().__init__(analyses_cache, function)
self.inst_offspring = {}
# "data dependency analysis"
dda: dict[IRInstruction, OrderedSet[IRInstruction]]
# "effect dependency analysis"
eda: dict[IRInstruction, OrderedSet[IRInstruction]]

def run_pass(self) -> None:
self.inst_offspring = {}
self.data_offspring = {}
self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()

self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
basic_blocks = list(self.function.get_basic_blocks())

self.function.clear_basic_blocks()
for bb in basic_blocks:
for bb in self.function.get_basic_blocks():
self._process_basic_block(bb)

self.analyses_cache.invalidate_analysis(LivenessAnalysis)

def _process_basic_block(self, bb: IRBasicBlock) -> None:
self.function.append_basic_block(bb)

self._calculate_dependency_graphs(bb)
self.instructions = list(bb.pseudo_instructions)
non_phi_instructions = list(bb.non_phi_instructions)

self.visited_instructions = OrderedSet()
for inst in non_phi_instructions:
self._calculate_instruction_offspring(inst)
for inst in bb.instructions:
self._calculate_data_offspring(inst)

# Compute entry points in the graph of instruction dependencies
entry_instructions: OrderedSet[IRInstruction] = OrderedSet(non_phi_instructions)
for inst in non_phi_instructions:
to_remove = self.ida.get(inst, OrderedSet())
if len(to_remove) > 0:
entry_instructions.dropmany(to_remove)
to_remove = self.dda.get(inst, OrderedSet()) | self.eda.get(inst, OrderedSet())
entry_instructions.dropmany(to_remove)

entry_instructions_list = list(entry_instructions)

# Move the terminator instruction to the end of the list
self._move_terminator_to_end(entry_instructions_list)

self.visited_instructions = OrderedSet()
for inst in entry_instructions_list:
self._process_instruction_r(self.instructions, inst)

bb.instructions = self.instructions
assert bb.is_terminated, f"Basic block should be terminated {bb}"

def _move_terminator_to_end(self, instructions: list[IRInstruction]) -> None:
terminator = next((inst for inst in instructions if inst.is_bb_terminator), None)
if terminator is None:
raise ValueError(f"Basic block should have a terminator instruction {self.function}")
instructions.remove(terminator)
instructions.append(terminator)

def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInstruction):
if inst in self.visited_instructions:
return
Expand All @@ -76,14 +60,23 @@ def _process_instruction_r(self, instructions: list[IRInstruction], inst: IRInst
if inst.is_pseudo:
return

children = list(self.ida[inst])
children = list(self.dda[inst] | self.eda[inst])

def key(x):
cost = inst.operands.index(x.output) if x.output in inst.operands else 0
return cost - len(self.inst_offspring[x]) * 0.5
def cost(x: IRInstruction) -> int | float:
if x in self.eda[inst] or inst.flippable:
ret = -1 * int(len(self.data_offspring[x]) > 0)
else:
assert x in self.dda[inst] # sanity check
assert x.output is not None # help mypy
ret = inst.operands.index(x.output)
return ret

# heuristic: sort by size of child dependency graph
children.sort(key=key)
orig_children = children.copy()
children.sort(key=cost)

if inst.flippable and (orig_children != children):
inst.flip()

for dep_inst in children:
self._process_instruction_r(instructions, dep_inst)
Expand All @@ -92,7 +85,8 @@ def key(x):

def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None:
# ida: instruction dependency analysis
self.ida = defaultdict(OrderedSet)
self.dda = defaultdict(OrderedSet)
self.eda = defaultdict(OrderedSet)

non_phis = list(bb.non_phi_instructions)

Expand All @@ -106,33 +100,31 @@ def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None:
for op in inst.operands:
dep = self.dfg.get_producing_instruction(op)
if dep is not None and dep.parent == bb:
self.ida[inst].add(dep)
self.dda[inst].add(dep)

write_effects = inst.get_write_effects()
read_effects = inst.get_read_effects()

for write_effect in write_effects:
if write_effect in last_read_effects:
self.ida[inst].add(last_read_effects[write_effect])
self.eda[inst].add(last_read_effects[write_effect])
last_write_effects[write_effect] = inst

for read_effect in read_effects:
if read_effect in last_write_effects and last_write_effects[read_effect] != inst:
self.ida[inst].add(last_write_effects[read_effect])
self.eda[inst].add(last_write_effects[read_effect])
last_read_effects[read_effect] = inst

def _calculate_instruction_offspring(self, inst: IRInstruction):
if inst in self.inst_offspring:
return self.inst_offspring[inst]
def _calculate_data_offspring(self, inst: IRInstruction):
if inst in self.data_offspring:
return self.data_offspring[inst]

self.inst_offspring[inst] = self.ida[inst].copy()
self.data_offspring[inst] = self.dda[inst].copy()

deps = self.ida[inst]
deps = self.dda[inst]
for dep_inst in deps:
assert inst.parent == dep_inst.parent
if dep_inst.opcode == "store":
continue
res = self._calculate_instruction_offspring(dep_inst)
self.inst_offspring[inst] |= res
res = self._calculate_data_offspring(dep_inst)
self.data_offspring[inst] |= res

return self.inst_offspring[inst]
return self.data_offspring[inst]
Loading