-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Ayush Kamat <[email protected]>
- Loading branch information
1 parent
0243e08
commit afb4d96
Showing
5 changed files
with
425 additions
and
354 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
try: | ||
from functools import cache | ||
except ImportError: | ||
from functools import lru_cache as cache | ||
|
||
import json | ||
from dataclasses import dataclass, field | ||
from enum import Enum | ||
from pathlib import Path | ||
from typing import Dict, List, Optional, Set, Tuple, TypedDict | ||
|
||
from typing_extensions import NotRequired, Self | ||
|
||
from latch_cli.utils import identifier_from_str | ||
|
||
|
||
class VertexType(str, Enum):
    """Kinds of nodes that can appear in a workflow DAG.

    A ``str`` mixin so values compare/serialize as their plain string names
    (matches the ``type`` field of the DAG JSON payload).
    """

    Process = "Process"
    Operator = "Operator"
    SubWorkflow = "SubWorkflow"
    Conditional = "Conditional"
    Generator = "Generator"
    Input = "Input"
|
||
|
||
@dataclass(frozen=True)
class Vertex:
    """A single node in a workflow DAG.

    Frozen (immutable and hashable) so vertices can key the dict/set
    adjacency structures built by ``DAG``.
    """

    # unique identifier of the vertex within its DAG
    id: str
    # display name; sanitized via `identifier_from_str` and truncated to 128
    # characters by `DAG.from_path`
    label: str
    # what kind of node this is — see VertexType
    type: VertexType
|
||
|
||
@dataclass(frozen=True)
class Edge:
    """A directed edge between two vertices of a DAG.

    ``src``/``dest`` hold vertex *ids* (strings), not ``Vertex`` objects —
    resolve them via ``DAG.src(e)`` / ``DAG.dest(e)``.
    """

    label: str
    src: str
    dest: str
    # NOTE(review): presumably marks which arm of a Conditional vertex this
    # edge belongs to (True/False), absent for ordinary edges — confirm
    # against the producer of the DAG JSON
    branch: Optional[bool] = None
|
||
|
||
class _VertexContentJson(TypedDict):
    """Wire format of the ``content`` field of a vertex in the DAG JSON."""

    id: str
    label: str
    type: VertexType
    # present in the payload but not read by `DAG.from_path`
    processMeta: Optional[Dict]
|
||
|
||
class _VertexJson(TypedDict):
    """Wire format of one entry of the ``vertices`` array in the DAG JSON."""

    content: _VertexContentJson
|
||
|
||
class _EdgeContentJson(TypedDict):
    """Wire format of the ``content`` field of an edge in the DAG JSON.

    Keys mirror ``Edge`` exactly — ``DAG.from_path`` constructs edges via
    ``Edge(**content)``.
    """

    label: str
    src: str
    dest: str
    branch: NotRequired[Optional[bool]]
|
||
|
||
class _EdgeJson(TypedDict):
    """Wire format of one entry of the ``edges`` array in the DAG JSON."""

    content: _EdgeContentJson
|
||
|
||
class _DAGJson(TypedDict):
    """Top-level wire format of the serialized DAG JSON payload."""

    vertices: List[_VertexJson]
    edges: List[_EdgeJson]
|
||
|
||
@dataclass(frozen=True)
class DAG:
    """An in-memory workflow graph parsed from a serialized DAG JSON payload.

    ``frozen=True`` keeps instances hashable, which the ``@cache``-decorated
    methods below rely on.  Both list fields are excluded from the hash
    (``hash=False``) because lists are unhashable; equality still compares
    the full vertex/edge lists.
    """

    vertices: List[Vertex] = field(hash=False)
    edges: List[Edge] = field(hash=False)

    @classmethod
    def from_path(cls, p: Path) -> Self:
        """Parse a DAG from the JSON file at `p`.

        Vertex labels are normalized with `identifier_from_str` and truncated
        to 128 characters.  Parallel edges (same src/dest pair) beyond the
        first are dropped.

        Raises:
            FileNotFoundError: if `p` does not exist.
        """
        if not p.exists():
            # bugfix: this used to be a bare `raise` with no active
            # exception, which itself fails with "RuntimeError: No active
            # exception to re-raise" instead of a meaningful error
            raise FileNotFoundError(f"no DAG spec found at {p}")

        payload: _DAGJson = json.loads(p.read_text())

        vertices: List[Vertex] = []
        for v in payload["vertices"]:
            c = v["content"]
            vertices.append(
                Vertex(
                    id=c["id"],
                    label=identifier_from_str(c["label"])[:128],
                    type=c["type"],
                )
            )

        edges: List[Edge] = []
        edge_set: Set[Tuple[str, str]] = set()
        for e in payload["edges"]:
            c = e["content"]
            t = (c["src"], c["dest"])

            if t in edge_set:
                # disallow multiple edges between the same pair of vertices
                continue

            edges.append(Edge(**c))
            edge_set.add(t)

        return cls(vertices, edges)

    @cache
    def _vertices_by_id(self) -> Dict[str, Vertex]:
        """Index of this DAG's vertices keyed by id.

        NOTE(review): ``@cache`` on instance methods keeps every instance
        alive for the life of the process (ruff B019) — acceptable for a
        short-lived CLI invocation, but worth confirming.
        """
        return {v.id: v for v in self.vertices}

    @cache
    def src(self, e: Edge) -> Vertex:
        """Resolve the source vertex of `e`."""
        return self._vertices_by_id()[e.src]

    @cache
    def dest(self, e: Edge) -> Vertex:
        """Resolve the destination vertex of `e`."""
        return self._vertices_by_id()[e.dest]

    @cache
    def ancestors(self) -> Dict[Vertex, List[Vertex]]:
        """Direct predecessors of every vertex (one hop, not transitive)."""
        res: Dict[Vertex, List[Vertex]] = {v: [] for v in self.vertices}

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.dest]].append(by_id[edge.src])

        return res

    @cache
    def inbound_edges(self) -> Dict[Vertex, List[Edge]]:
        """Edges entering each vertex."""
        res: Dict[Vertex, List[Edge]] = {v: [] for v in self.vertices}

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.dest]].append(edge)

        return res

    @cache
    def descendants(self) -> Dict[Vertex, List[Vertex]]:
        """Direct successors of every vertex (one hop, not transitive)."""
        res: Dict[Vertex, List[Vertex]] = {v: [] for v in self.vertices}

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.src]].append(by_id[edge.dest])

        return res

    @cache
    def outbound_edges(self) -> Dict[Vertex, List[Edge]]:
        """Edges leaving each vertex."""
        res: Dict[Vertex, List[Edge]] = {v: [] for v in self.vertices}

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.src]].append(edge)

        return res

    @property
    @cache
    def source_vertices(self) -> List[Vertex]:
        """Vertices with no inbound edges."""
        return [v for v, upstream in self.ancestors().items() if len(upstream) == 0]

    @property
    @cache
    def sink_vertices(self) -> List[Vertex]:
        """Vertices with no outbound edges."""
        return [v for v, downstream in self.descendants().items() if len(downstream) == 0]

    @classmethod
    def _resolve_subworkflows_helper(
        cls,
        wf_name: str,
        dags: Dict[str, Self],
        sub_wf_dependencies: Dict[str, List[str]],
    ):
        """Inline every SubWorkflow vertex of ``dags[wf_name]`` in place.

        Dependencies are resolved recursively first so sub-DAGs are fully
        expanded before being spliced in.  Inlined vertex ids are namespaced
        as ``"<subworkflow vertex id>_<sub vertex id>"`` to keep them unique.
        Re-resolving an already-resolved DAG rebuilds it unchanged (no
        SubWorkflow vertices remain), so shared dependencies are safe, if
        wasteful.
        """
        for dep in sub_wf_dependencies[wf_name]:
            cls._resolve_subworkflows_helper(dep, dags, sub_wf_dependencies)

        dag = dags[wf_name]

        new_vertices: List[Vertex] = []
        new_edges: List[Edge] = []
        for v in dag.vertices:
            if v.type != VertexType.SubWorkflow:
                new_vertices.append(v)
                continue

            # splice in the (already resolved) sub-DAG's vertices and edges
            sub_dag = dags[v.label]
            for sub_v in sub_dag.vertices:
                new_vertices.append(
                    Vertex(
                        id="_".join([v.id, sub_v.id]),
                        label=sub_v.label,
                        type=sub_v.type,
                    )
                )

            for sub_e in sub_dag.edges:
                new_edges.append(
                    Edge(
                        label=sub_e.label,
                        src="_".join([v.id, sub_e.src]),
                        dest="_".join([v.id, sub_e.dest]),
                    )
                )

        # rewire edges that touched a SubWorkflow vertex: an edge out of one
        # fans out from the sub-DAG's sinks, an edge into one fans in to the
        # sub-DAG's sources
        ids = set(v.id for v in new_vertices)
        for e in dag.edges:
            if e.src in ids:
                srcs = [e.src]
            else:
                sub_dag = dags[dag.src(e).label]

                srcs = ["_".join([e.src, v.id]) for v in sub_dag.sink_vertices]

            if e.dest in ids:
                dests = [e.dest]
            else:
                sub_dag = dags[dag.dest(e).label]

                dests = ["_".join([e.dest, v.id]) for v in sub_dag.source_vertices]

            for src in srcs:
                for dest in dests:
                    new_edges.append(
                        Edge(
                            label=e.label,
                            src=src,
                            dest=dest,
                        )
                    )

        dags[wf_name] = cls(new_vertices, new_edges)

    @classmethod
    def resolve_subworkflows(cls, dags: Dict[str, Self]) -> Dict[str, Self]:
        """Return the top-level workflows of `dags` with every SubWorkflow
        vertex recursively inlined.

        Only "source" workflows (those not referenced as a subworkflow by any
        other entry) appear in the result.  `dags` itself is not mutated; a
        shallow copy accumulates the resolved DAGs.
        """
        dependencies: Dict[str, List[str]] = {}
        sources = set(dags.keys())

        for wf_name, dag in dags.items():
            deps: List[str] = []
            for v in dag.vertices:
                if v.type != VertexType.SubWorkflow:
                    continue

                deps.append(v.label)
                # anything referenced as a subworkflow is not top-level
                sources.discard(v.label)

            dependencies[wf_name] = deps

        # todo(ayush): idk the time/space complexity of this but its certainly not great
        resolved_dags = dags.copy()
        res: Dict[str, Self] = {}
        for source in sources:
            cls._resolve_subworkflows_helper(source, resolved_dags, dependencies)
            res[source] = resolved_dags[source]

        return res

    def _toposort_helper(
        self,
        cur: Vertex,
        res: List[Vertex],
        seen: Optional[Set[Vertex]] = None,
    ):
        """Append `cur` to `res` after all of its (transitive) ancestors.

        `seen` guards against emitting a vertex more than once when it is
        reachable along multiple paths; pass one shared set across sibling
        calls.  Recursion depth is bounded by the longest ancestor chain
        (the graph is assumed acyclic).
        """
        if seen is None:
            seen = set()

        if cur in seen:
            return
        seen.add(cur)

        for anc in self.ancestors()[cur]:
            self._toposort_helper(anc, res, seen)

        res.append(cur)

    @cache
    def toposorted(self) -> List[Vertex]:
        """Vertices in topological order: every ancestor precedes its
        descendants.

        bugfix: a shared `seen` set is threaded through the traversal — the
        previous version re-emitted any vertex reachable via multiple paths
        or from multiple sinks, producing duplicates in the result.
        """
        res: List[Vertex] = []
        seen: Set[Vertex] = set()

        for sink in self.sink_vertices:
            self._toposort_helper(sink, res, seen)

        return res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.