Skip to content

Commit

Permalink
registration
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush Kamat <[email protected]>
  • Loading branch information
ayushkamat committed Feb 3, 2024
1 parent 0243e08 commit afb4d96
Show file tree
Hide file tree
Showing 5 changed files with 425 additions and 354 deletions.
6 changes: 3 additions & 3 deletions latch_cli/extras/common/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def serialize(
+ [admin_lp]
]

click.secho("\nSerializing workflow entities", bold=True)
persist_registrable_entities(registrable_entities, output_dir)

if not write_spec:
return

Expand All @@ -352,9 +355,6 @@ def serialize(
cur.parent.mkdir(parents=True, exist_ok=True)
cur.write_text(MessageToJson(entity))

click.secho("\nSerializing workflow entities", bold=True)
persist_registrable_entities(registrable_entities, output_dir)


def binding_data_from_python(
expected_literal_type: type_models.LiteralType,
Expand Down
300 changes: 300 additions & 0 deletions latch_cli/extras/nextflow/dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
try:
    # functools.cache is only available on Python >= 3.9; fall back to
    # lru_cache used as a bare decorator on older interpreters.
    # NOTE(review): the fallback keeps lru_cache's default maxsize of 128,
    # so it is bounded rather than a true unbounded cache — confirm this is
    # acceptable for the cached DAG methods below.
    from functools import cache
except ImportError:
    from functools import lru_cache as cache

import json
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, TypedDict

from typing_extensions import NotRequired, Self

from latch_cli.utils import identifier_from_str


class VertexType(str, Enum):
    """Kinds of nodes appearing in the DAG JSON payload.

    Subclasses ``str`` so members compare equal to the raw ``type`` strings
    found in the parsed JSON (see ``DAG.from_path``).
    """

    Process = "Process"
    Operator = "Operator"
    SubWorkflow = "SubWorkflow"
    Conditional = "Conditional"
    Generator = "Generator"
    Input = "Input"


@dataclass(frozen=True)
class Vertex:
    """A single node in the DAG.

    Frozen (immutable + hashable) so instances can be used as dict keys and
    set members by the caching/adjacency helpers in ``DAG``.
    """

    # unique identifier within the DAG; edges reference vertices by this id
    id: str
    # display name; normalized and truncated to 128 chars on load
    label: str
    type: VertexType


@dataclass(frozen=True)
class Edge:
    """A directed edge between two vertices, referenced by vertex id."""

    label: str
    # id of the source Vertex
    src: str
    # id of the destination Vertex
    dest: str
    # presumably marks which branch of a Conditional this edge belongs to —
    # TODO(review): confirm semantics; only round-tripped from JSON here
    branch: Optional[bool] = None


class _VertexContentJson(TypedDict):
    """Shape of the ``content`` object for one vertex in the DAG JSON file."""

    id: str
    label: str
    type: VertexType
    # present in the payload but not read by DAG.from_path
    processMeta: Optional[Dict]


class _VertexJson(TypedDict):
    """One element of the top-level ``vertices`` array: a wrapper around content."""

    content: _VertexContentJson


class _EdgeContentJson(TypedDict):
    """Shape of the ``content`` object for one edge in the DAG JSON file.

    Keys mirror the ``Edge`` dataclass exactly — ``DAG.from_path`` constructs
    edges via ``Edge(**content)``.
    """

    label: str
    src: str
    dest: str
    # optional in the payload; Edge defaults it to None
    branch: NotRequired[Optional[bool]]


class _EdgeJson(TypedDict):
    """One element of the top-level ``edges`` array: a wrapper around content."""

    content: _EdgeContentJson


class _DAGJson(TypedDict):
    """Top-level shape of the DAG JSON document consumed by ``DAG.from_path``."""

    vertices: List[_VertexJson]
    edges: List[_EdgeJson]


@dataclass(frozen=True)
class DAG:
    """Directed acyclic graph of workflow entities loaded from a JSON file.

    ``frozen=True`` plus ``hash=False`` on both list fields (lists are
    unhashable) gives every instance a constant generated ``__hash__``, which
    is what lets the ``@cache``-decorated methods below key on ``self``.

    NOTE(review): ``@cache`` on instance methods keeps each DAG alive for the
    lifetime of the process (functools caches retain ``self``); presumably
    fine for a short-lived CLI invocation — confirm if DAGs are ever built in
    a long-running service.
    """

    vertices: List[Vertex] = field(hash=False)
    edges: List[Edge] = field(hash=False)

    @classmethod
    def from_path(cls, p: Path) -> Self:
        """Parse a ``DAG`` from the JSON payload at ``p``.

        Vertex labels are normalized via ``identifier_from_str`` and truncated
        to 128 characters. Parallel edges (a repeated ``(src, dest)`` pair)
        are dropped, keeping only the first occurrence.

        Raises:
            FileNotFoundError: if ``p`` does not exist.
        """
        if not p.exists():
            # bugfix: this was a bare `raise`, which (with no active
            # exception) surfaces as RuntimeError("No active exception to
            # re-raise") instead of a meaningful error
            raise FileNotFoundError(f"DAG description file does not exist: {p}")

        payload: _DAGJson = json.loads(p.read_text())

        vertices: List[Vertex] = []
        for v in payload["vertices"]:
            c = v["content"]
            vertices.append(
                Vertex(
                    id=c["id"],
                    label=identifier_from_str(c["label"])[:128],
                    type=c["type"],
                )
            )

        edges: List[Edge] = []
        edge_set: Set[Tuple[str, str]] = set()
        for e in payload["edges"]:
            c = e["content"]
            t = (c["src"], c["dest"])

            if t in edge_set:
                # disallow multiple edges between the same pair of vertices
                continue

            edges.append(Edge(**c))
            edge_set.add(t)

        return cls(vertices, edges)

    @cache
    def _vertices_by_id(self) -> Dict[str, Vertex]:
        """Index vertices by id (computed once per instance)."""
        res: Dict[str, Vertex] = {}
        for v in self.vertices:
            res[v.id] = v

        return res

    @cache
    def src(self, e: Edge) -> Vertex:
        """Return the source ``Vertex`` of edge ``e``."""
        return self._vertices_by_id()[e.src]

    @cache
    def dest(self, e: Edge) -> Vertex:
        """Return the destination ``Vertex`` of edge ``e``."""
        return self._vertices_by_id()[e.dest]

    @cache
    def ancestors(self) -> Dict[Vertex, List[Vertex]]:
        """Map each vertex to its direct predecessors (parents, not transitive)."""
        res: Dict[Vertex, List[Vertex]] = {}
        for v in self.vertices:
            res[v] = []

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.dest]].append(by_id[edge.src])

        return res

    @cache
    def inbound_edges(self) -> Dict[Vertex, List[Edge]]:
        """Map each vertex to the edges pointing into it."""
        res: Dict[Vertex, List[Edge]] = {}
        for v in self.vertices:
            res[v] = []

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.dest]].append(edge)

        return res

    @cache
    def descendants(self) -> Dict[Vertex, List[Vertex]]:
        """Map each vertex to its direct successors (children, not transitive)."""
        res: Dict[Vertex, List[Vertex]] = {}
        for v in self.vertices:
            res[v] = []

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.src]].append(by_id[edge.dest])

        return res

    @cache
    def outbound_edges(self) -> Dict[Vertex, List[Edge]]:
        """Map each vertex to the edges leaving it."""
        res: Dict[Vertex, List[Edge]] = {}
        for v in self.vertices:
            res[v] = []

        by_id = self._vertices_by_id()
        for edge in self.edges:
            res[by_id[edge.src]].append(edge)

        return res

    @property
    @cache
    def source_vertices(self) -> List[Vertex]:
        """Vertices with no incoming edges."""
        res: List[Vertex] = []

        for v, upstream in self.ancestors().items():
            if len(upstream) != 0:
                continue

            res.append(v)

        return res

    @property
    @cache
    def sink_vertices(self) -> List[Vertex]:
        """Vertices with no outgoing edges."""
        res: List[Vertex] = []

        for v, downstream in self.descendants().items():
            if len(downstream) != 0:
                continue

            res.append(v)

        return res

    @classmethod
    def _resolve_subworkflows_helper(
        cls,
        wf_name: str,
        dags: Dict[str, Self],
        sub_wf_dependencies: Dict[str, List[str]],
    ):
        """Recursively inline every SubWorkflow vertex of ``wf_name``.

        Mutates ``dags`` in place: afterwards ``dags[wf_name]`` contains no
        SubWorkflow vertices. Inlined vertex/edge ids are prefixed with the
        SubWorkflow vertex's id to keep ids unique in the flattened graph.
        """
        # resolve dependencies bottom-up first so each sub-DAG is already flat
        # by the time it is spliced in
        for dep in sub_wf_dependencies[wf_name]:
            cls._resolve_subworkflows_helper(dep, dags, sub_wf_dependencies)

        dag = dags[wf_name]

        new_vertices: List[Vertex] = []
        new_edges: List[Edge] = []
        for v in dag.vertices:
            if v.type != VertexType.SubWorkflow:
                new_vertices.append(v)
                continue

            # splice in the sub-workflow's graph, namespacing ids by the
            # SubWorkflow vertex's id
            sub_dag = dags[v.label]
            for sub_v in sub_dag.vertices:
                new_vertices.append(
                    Vertex(
                        id="_".join([v.id, sub_v.id]),
                        label=sub_v.label,
                        type=sub_v.type,
                    )
                )

            for sub_e in sub_dag.edges:
                # NOTE(review): sub_e.branch is not propagated here — confirm
                # branch information is irrelevant after inlining
                new_edges.append(
                    Edge(
                        label=sub_e.label,
                        src="_".join([v.id, sub_e.src]),
                        dest="_".join([v.id, sub_e.dest]),
                    )
                )

        # rewire edges that touched a SubWorkflow vertex: an edge into a
        # sub-workflow fans out to the sub-DAG's source vertices, an edge out
        # of one fans in from the sub-DAG's sink vertices
        ids = set(v.id for v in new_vertices)
        for e in dag.edges:
            if e.src in ids:
                srcs = [e.src]
            else:
                sub_dag = dags[dag.src(e).label]

                srcs = ["_".join([e.src, v.id]) for v in sub_dag.sink_vertices]

            if e.dest in ids:
                dests = [e.dest]
            else:
                sub_dag = dags[dag.dest(e).label]

                dests = ["_".join([e.dest, v.id]) for v in sub_dag.source_vertices]

            for src in srcs:
                for dest in dests:
                    new_edges.append(
                        Edge(
                            label=e.label,
                            src=src,
                            dest=dest,
                        )
                    )

        dags[wf_name] = cls(new_vertices, new_edges)

    @classmethod
    def resolve_subworkflows(cls, dags: Dict[str, Self]) -> Dict[str, Self]:
        """Flatten all sub-workflows, returning only the top-level DAGs.

        A "source" workflow is one that is never referenced as a SubWorkflow
        by any other DAG in ``dags``.
        """
        dependencies: Dict[str, List[str]] = {}
        sources = set(dags.keys())

        for wf_name, dag in dags.items():
            deps: List[str] = []
            for v in dag.vertices:
                if v.type != VertexType.SubWorkflow:
                    continue

                deps.append(v.label)
                sources.discard(v.label)

            dependencies[wf_name] = deps

        # todo(ayush): idk the time/space complexity of this but its certainly not great
        resolved_dags = dags.copy()
        res: Dict[str, Self] = {}
        for source in sources:
            cls._resolve_subworkflows_helper(source, resolved_dags, dependencies)
            res[source] = resolved_dags[source]

        return res

    def _toposort_helper(
        self,
        cur: Vertex,
        res: List[Vertex],
        seen: Optional[Set[Vertex]] = None,
    ):
        """Post-order DFS over ancestors so parents precede children in ``res``.

        ``seen`` (new, optional for backward compatibility) prevents a vertex
        reachable through multiple paths from being visited — and appended —
        more than once.
        """
        if seen is None:
            seen = set()

        # bugfix: without this guard, any vertex on multiple paths to a sink
        # was appended once per path (duplicate output) and traversal was
        # exponential on diamond-shaped graphs
        if cur in seen:
            return
        seen.add(cur)

        for x in self.ancestors()[cur]:
            self._toposort_helper(x, res, seen)

        res.append(cur)

    @cache
    def toposorted(self) -> List[Vertex]:
        """Return all vertices in topological order, each exactly once."""
        res: List[Vertex] = []
        seen: Set[Vertex] = set()

        for sink in self.sink_vertices:
            self._toposort_helper(sink, res, seen)

        return res
2 changes: 1 addition & 1 deletion latch_cli/extras/nextflow/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def serialize_nf(
image_name: str,
dkr_repo: str,
):
serialize(nf_wf, output_dir, image_name, dkr_repo, write_spec=True)
serialize(nf_wf, output_dir, image_name, dkr_repo)
Loading

0 comments on commit afb4d96

Please sign in to comment.