diff --git a/daft/internal/rule.py b/daft/internal/rule.py deleted file mode 100644 index 29a42fa8fb..0000000000 --- a/daft/internal/rule.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from typing import Callable, Generic, Optional, TypeVar - -from daft.internal.treenode import TreeNode - -TreeNodeType = TypeVar("TreeNodeType", bound="TreeNode") - -RuleFn = Callable[[TreeNodeType, TreeNodeType], Optional[TreeNodeType]] - - -def get_all_subclasses(input_type: type) -> list[type]: - result = [input_type] - - def helper(t: type): - subclasses = t.__subclasses__() - result.extend(subclasses) - for sc in subclasses: - helper(sc) - - helper(input_type) - return result - - -class Rule(Generic[TreeNodeType]): - def __init__(self) -> None: - self._fn_registry: dict[tuple[type[TreeNodeType], type[TreeNodeType]], RuleFn] = dict() - - def register_fn(self, parent_type: type, child_type: type, fn: RuleFn, override: bool = False) -> None: - for p_subclass in get_all_subclasses(parent_type): - for c_subtype in get_all_subclasses(child_type): - type_tuple = (p_subclass, c_subtype) - if type_tuple in self._fn_registry: - if override: - self._fn_registry[type_tuple] = fn - else: - raise ValueError(f"Rule already registered for {type_tuple}") - else: - self._fn_registry[type_tuple] = fn - - def dispatch_fn(self, parent: TreeNodeType, child: TreeNodeType) -> RuleFn | None: - type_tuple = (type(parent), type(child)) - if type_tuple not in self._fn_registry: - return None - return self._fn_registry.get(type_tuple, None) - - def apply(self, parent: TreeNodeType, child: TreeNodeType) -> TreeNodeType | None: - fn = self.dispatch_fn(parent, child) - if fn is None: - return None - return fn(parent, child) diff --git a/daft/internal/rule_runner.py b/daft/internal/rule_runner.py deleted file mode 100644 index 0e0cb32249..0000000000 --- a/daft/internal/rule_runner.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import Generic, TypeVar - -from daft.internal.rule import Rule -from daft.internal.treenode import TreeNode - -logger = logging.getLogger(__name__) - -TreeNodeType = TypeVar("TreeNodeType", bound="TreeNode") - - -@dataclass -class FixedPointPolicy: - num_runs: int - - -Once = FixedPointPolicy(1) - - -@dataclass -class RuleBatch(Generic[TreeNodeType]): - name: str - mode: FixedPointPolicy - rules: list[Rule[TreeNodeType]] - - -class RuleRunner(Generic[TreeNodeType]): - def __init__(self, batches: list[RuleBatch[TreeNodeType]]) -> None: - self._batches = batches - - def optimize(self, root: TreeNodeType) -> TreeNodeType: - from copy import deepcopy - - root = deepcopy(root) - for batch in self._batches: - root = self._run_single_batch(root, batch) - return root - - def __call__(self, root: TreeNodeType) -> TreeNodeType: - return self.optimize(root) - - def _run_single_batch(self, root: TreeNodeType, batch: RuleBatch) -> TreeNodeType: - logger.debug(f"Running optimizer batch: {batch.name}") - max_runs = batch.mode.num_runs - applied_least_one_rule = False - for i in range(max_runs): - if i > 0 and not applied_least_one_rule: - logger.debug(f"Optimizer batch: {batch.name} terminating at iteration {i}. No rules applied") - break - logger.debug(f"Running optimizer batch: {batch.name}. Iteration {i} out of maximum {max_runs}") - applied_least_one_rule = False - for rule in batch.rules: - result = root.apply_and_trickle_down(rule) - if result is not None: - root = result - applied_least_one_rule = True - - else: - if applied_least_one_rule: - logger.debug( - f"Optimizer Batch {batch.name} reached max iteration {max_runs} and had changes in the last iteration" - ) - - return root diff --git a/daft/internal/treenode.py b/daft/internal/treenode.py deleted file mode 100644 index 9de36ac300..0000000000 --- a/daft/internal/treenode.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import annotations - -import logging -import os -import typing -from typing import TYPE_CHECKING, Generic, List, TypeVar, cast - -if TYPE_CHECKING: - from daft.internal.rule import Rule - -logger = logging.getLogger(__name__) - -TreeNodeType = TypeVar("TreeNodeType", bound="TreeNode") - - -class TreeNode(Generic[TreeNodeType]): - _registered_children: list[TreeNodeType] - - def __init__(self) -> None: - self._registered_children: list[TreeNodeType] = [] - - def _children(self) -> list[TreeNodeType]: - return self._registered_children - - def _register_child(self, child: TreeNodeType) -> int: - self._registered_children.append(child) - return len(self._registered_children) - 1 - - def apply_and_trickle_down(self, rule: Rule[TreeNodeType]) -> TreeNodeType | None: - root = cast(TreeNodeType, self) - continue_looping = True - made_change = False - - # Apply rule to self and its children - while continue_looping: - for child in root._children(): - fn = rule.dispatch_fn(root, child) - - if fn is None: - continue - maybe_new_root = fn(root, child) - - if maybe_new_root is not None: - root = maybe_new_root - made_change = True - break - else: - continue_looping = False - - # Recursively apply_and_trickle_down to children - n_children = len(root._children()) - for i in range(n_children): - maybe_new_child = root._registered_children[i].apply_and_trickle_down(rule) - if maybe_new_child is not None: - root._registered_children[i] = maybe_new_child - made_change = True - - if made_change: - return root - else: - return None - - def to_dot_file(self, filename: str | None = None) -> str: - dot_data = self.to_dot() - base_path = "log" - if filename is None: - os.makedirs(base_path, exist_ok=True) - filename = f"{base_path}/{hash(dot_data)}.dot" - with open(filename, "w") as f: - f.write(dot_data) - logger.info(f"Wrote Dot file to {filename}") - return filename - - def to_dot(self) -> str: - try: - import pydot - except ImportError: - raise ImportError( - "Error while importing pydot: please manually install `pip install pydot` for tree visualizations" - ) - - graph: pydot.Graph = pydot.Dot("TreeNode", graph_type="digraph", bgcolor="white") # type: ignore - counter = 0 - - def recurser(node: TreeNode) -> int: - nonlocal counter - desc = repr(node) - my_id = counter - myself = pydot.Node(my_id, label=f"{desc}") - graph.add_node(myself) - counter += 1 - for child in node._children(): - child_id = recurser(child) - edge = pydot.Edge(str(my_id), str(child_id), color="black") - graph.add_edge(edge) # type: ignore - return my_id - - recurser(self) - return graph.to_string() # type: ignore - - def post_order(self) -> list[TreeNodeType]: - nodes = [] - - def helper(curr: TreeNode[TreeNodeType]) -> None: - for child in curr._children(): - helper(child) - nodes.append(curr) - - helper(self) - return typing.cast(List[TreeNodeType], nodes)