diff --git a/latch_cli/extras/common/serialize.py b/latch_cli/extras/common/serialize.py
index 05a1b402..f43cf7ea 100644
--- a/latch_cli/extras/common/serialize.py
+++ b/latch_cli/extras/common/serialize.py
@@ -327,6 +327,9 @@ def serialize(
         + [admin_lp]
     ]
 
+    click.secho("\nSerializing workflow entities", bold=True)
+    persist_registrable_entities(registrable_entities, output_dir)
+
     if not write_spec:
         return
 
@@ -352,9 +355,6 @@ def serialize(
         cur.parent.mkdir(parents=True, exist_ok=True)
         cur.write_text(MessageToJson(entity))
 
-    click.secho("\nSerializing workflow entities", bold=True)
-    persist_registrable_entities(registrable_entities, output_dir)
-
 
 def binding_data_from_python(
     expected_literal_type: type_models.LiteralType,
diff --git a/latch_cli/extras/nextflow/dag.py b/latch_cli/extras/nextflow/dag.py
new file mode 100644
index 00000000..7d157361
--- /dev/null
+++ b/latch_cli/extras/nextflow/dag.py
@@ -0,0 +1,300 @@
+try:
+    # functools.cache is only available on Python >= 3.9
+    from functools import cache
+except ImportError:
+    from functools import lru_cache as cache
+
+import json
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Tuple, TypedDict
+
+from typing_extensions import NotRequired, Self
+
+from latch_cli.utils import identifier_from_str
+
+
+class VertexType(str, Enum):
+    Process = "Process"
+    Operator = "Operator"
+    SubWorkflow = "SubWorkflow"
+    Conditional = "Conditional"
+    Generator = "Generator"
+    Input = "Input"
+
+
+@dataclass(frozen=True)
+class Vertex:
+    id: str
+    label: str
+    type: VertexType
+
+
+@dataclass(frozen=True)
+class Edge:
+    label: str
+    src: str
+    dest: str
+    branch: Optional[bool] = None
+
+
+class _VertexContentJson(TypedDict):
+    id: str
+    label: str
+    type: VertexType
+    processMeta: Optional[Dict]
+
+
+class _VertexJson(TypedDict):
+    content: _VertexContentJson
+
+
+class _EdgeContentJson(TypedDict):
+    label: str
+    src: str
+    dest: str
+    branch: NotRequired[Optional[bool]]
+
+
+class _EdgeJson(TypedDict):
+    content: _EdgeContentJson
+
+
+class _DAGJson(TypedDict):
+    vertices: List[_VertexJson]
+    edges: List[_EdgeJson]
+
+
+@dataclass(frozen=True)
+class DAG:
+    vertices: List[Vertex] = field(hash=False)
+    edges: List[Edge] = field(hash=False)
+
+    @classmethod
+    def from_path(cls, p: Path) -> Self:
+        if not p.exists():
+            raise FileNotFoundError(p)  # todo(ayush): better errors
+
+        payload: _DAGJson = json.loads(p.read_text())
+
+        vertices: List[Vertex] = []
+        for v in payload["vertices"]:
+            c = v["content"]
+            vertices.append(
+                Vertex(
+                    id=c["id"],
+                    label=identifier_from_str(c["label"])[:128],
+                    type=VertexType(c["type"]),
+                )
+            )
+
+        edges: List[Edge] = []
+        edge_set: Set[Tuple[str, str]] = set()
+        for e in payload["edges"]:
+            c = e["content"]
+            t = (c["src"], c["dest"])
+
+            if t in edge_set:
+                # collapse parallel edges between the same pair of vertices
+                continue
+
+            edges.append(Edge(**c))
+            edge_set.add(t)
+
+        return cls(vertices, edges)
+
+    @cache
+    def _vertices_by_id(self) -> Dict[str, Vertex]:
+        res: Dict[str, Vertex] = {}
+        for v in self.vertices:
+            res[v.id] = v
+
+        return res
+
+    @cache
+    def src(self, e: Edge) -> Vertex:
+        return self._vertices_by_id()[e.src]
+
+    @cache
+    def dest(self, e: Edge) -> Vertex:
+        return self._vertices_by_id()[e.dest]
+
+    @cache
+    def ancestors(self) -> Dict[Vertex, List[Vertex]]:
+        res: Dict[Vertex, List[Vertex]] = {}
+        for v in self.vertices:
+            res[v] = []
+
+        by_id = self._vertices_by_id()
+        for edge in self.edges:
+            res[by_id[edge.dest]].append(by_id[edge.src])
+
+        return res
+
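+    # ancestors(), descendants(), and the edge maps below are all O(V + E)
+    # adjacency tables, built in a single pass and memoized via @cache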
+    @cache
+    def inbound_edges(self) -> Dict[Vertex, List[Edge]]:
+        res: Dict[Vertex, List[Edge]] = {}
+        for v in self.vertices:
+            res[v] = []
+
+        by_id = self._vertices_by_id()
+        for edge in self.edges:
+            res[by_id[edge.dest]].append(edge)
+
+        return res
+
+    @cache
+    def descendants(self) -> Dict[Vertex, List[Vertex]]:
+        res: Dict[Vertex, List[Vertex]] = {}
+        for v in self.vertices:
+            res[v] = []
+
+        by_id = self._vertices_by_id()
+        for edge in self.edges:
+            res[by_id[edge.src]].append(by_id[edge.dest])
+
+        return res
+
+    @cache
+    def outbound_edges(self) -> Dict[Vertex, List[Edge]]:
+        res: Dict[Vertex, List[Edge]] = {}
+        for v in self.vertices:
+            res[v] = []
+
+        by_id = self._vertices_by_id()
+        for edge in self.edges:
+            res[by_id[edge.src]].append(edge)
+
+        return res
+
+    @property
+    @cache
+    def source_vertices(self) -> List[Vertex]:
+        res: List[Vertex] = []
+
+        for v, upstream in self.ancestors().items():
+            if len(upstream) != 0:
+                continue
+
+            res.append(v)
+
+        return res
+
+    @property
+    @cache
+    def sink_vertices(self) -> List[Vertex]:
+        res: List[Vertex] = []
+
+        for v, downstream in self.descendants().items():
+            if len(downstream) != 0:
+                continue
+
+            res.append(v)
+
+        return res
+
+    @classmethod
+    def _resolve_subworkflows_helper(
+        cls,
+        wf_name: str,
+        dags: Dict[str, Self],
+        sub_wf_dependencies: Dict[str, List[str]],
+    ):
+        for dep in sub_wf_dependencies[wf_name]:
+            cls._resolve_subworkflows_helper(dep, dags, sub_wf_dependencies)
+
+        dag = dags[wf_name]
+
+        new_vertices: List[Vertex] = []
+        new_edges: List[Edge] = []
+        for v in dag.vertices:
+            if v.type != VertexType.SubWorkflow:
+                new_vertices.append(v)
+                continue
+
+            # inline the subworkflow: copy its vertices and edges, prefixing
+            # their ids with the subworkflow vertex's id to keep them unique
+            sub_dag = dags[v.label]
+            for sub_v in sub_dag.vertices:
+                new_vertices.append(
+                    Vertex(
+                        id="_".join([v.id, sub_v.id]),
+                        label=sub_v.label,
+                        type=sub_v.type,
+                    )
+                )
+
+            for sub_e in sub_dag.edges:
+                new_edges.append(
+                    Edge(
+                        label=sub_e.label,
+                        src="_".join([v.id, sub_e.src]),
+                        dest="_".join([v.id, sub_e.dest]),
+                    )
+                )
+
+        ids = set(v.id for v in new_vertices)
+        for e in dag.edges:
+            if e.src in ids:
+                srcs = [e.src]
+            else:
+                # the source was a subworkflow vertex - connect from its sinks
+                sub_dag = dags[dag.src(e).label]
+
+                srcs = ["_".join([e.src, v.id]) for v in sub_dag.sink_vertices]
+
+            if e.dest in ids:
+                dests = [e.dest]
+            else:
+                # the destination was a subworkflow vertex - connect to its sources
+                sub_dag = dags[dag.dest(e).label]
+
+                dests = ["_".join([e.dest, v.id]) for v in sub_dag.source_vertices]
+
+            for src in srcs:
+                for dest in dests:
+                    new_edges.append(
+                        Edge(
+                            label=e.label,
+                            src=src,
+                            dest=dest,
+                        )
+                    )
+
+        dags[wf_name] = cls(new_vertices, new_edges)
+
+    @classmethod
+    def resolve_subworkflows(cls, dags: Dict[str, Self]) -> Dict[str, Self]:
+        dependencies: Dict[str, List[str]] = {}
+        sources = set(dags.keys())
+
+        for wf_name, dag in dags.items():
+            deps: List[str] = []
+            for v in dag.vertices:
+                if v.type != VertexType.SubWorkflow:
+                    continue
+
+                deps.append(v.label)
+                sources.discard(v.label)
+
+            dependencies[wf_name] = deps
+
+        # todo(ayush): idk the time/space complexity of this but it's certainly not great
+        resolved_dags = dags.copy()
+        res: Dict[str, Self] = {}
+        for source in sources:
+            cls._resolve_subworkflows_helper(source, resolved_dags, dependencies)
+            res[source] = resolved_dags[source]
+
+        return res
+
+    def _toposort_helper(self, cur: Vertex, res: List[Vertex], visited: Set[Vertex]):
+        # DFS over ancestors; track visited vertices so that nodes reachable
+        # through multiple paths are emitted exactly once
+        if cur in visited:
+            return
+        visited.add(cur)
+
+        for x in self.ancestors()[cur]:
+            self._toposort_helper(x, res, visited)
+
+        res.append(cur)
+
+    @cache
+    def toposorted(self) -> List[Vertex]:
+        res: List[Vertex] = []
+        visited: Set[Vertex] = set()
+
+        for sink in self.sink_vertices:
+            self._toposort_helper(sink, res, visited)
+
+        return res
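+
+
+# a minimal usage sketch (hypothetical file layout; the real entrypoint is
+# build_nf_wf in workflow.py):
+#
+#   dags = {
+#       p.name.rsplit(".", 2)[0]: DAG.from_path(p)
+#       for p in Path(".latch").glob("*.dag.json")
+#   }
+#   resolved = DAG.resolve_subworkflows(dags)
+#   for v in resolved["mainWorkflow"].toposorted():
+#       print(v.id, v.type, v.label)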
diff --git a/latch_cli/extras/nextflow/serialize.py b/latch_cli/extras/nextflow/serialize.py
index c66df473..53cff0a9 100644
--- a/latch_cli/extras/nextflow/serialize.py
+++ b/latch_cli/extras/nextflow/serialize.py
@@ -8,4 +8,4 @@ def serialize_nf(
     image_name: str,
     dkr_repo: str,
 ):
-    serialize(nf_wf, output_dir, image_name, dkr_repo, write_spec=True)
+    serialize(nf_wf, output_dir, image_name, dkr_repo)
diff --git a/latch_cli/extras/nextflow/workflow.py b/latch_cli/extras/nextflow/workflow.py
index 7d1ce497..505543c8 100644
--- a/latch_cli/extras/nextflow/workflow.py
+++ b/latch_cli/extras/nextflow/workflow.py
@@ -1,3 +1,4 @@
+import glob
 import importlib
 import json
 import os
@@ -62,29 +63,24 @@
 from latch.resources.tasks import custom_task
 from latch.types import metadata
 from latch.types.metadata import NextflowFileParameter, ParameterType, _IsDataclass
 
-from latch_cli.utils import identifier_from_str
+from ...menus import select_tui
+from ...utils import identifier_from_str
 from ..common.serialize import binding_from_python
 from ..common.utils import reindent, type_repr
+from .dag import DAG, VertexType
 from .types import (
     NextflowDAGEdge,
     NextflowDAGVertex,
     NextflowInputParamType,
     NextflowOutputParamType,
     NextflowParam,
-    VertexType,
 )
 from .utils import format_param_name
 
 
 class NextflowWorkflow(WorkflowBase, ClassStorageTaskResolver):
-    def __init__(
-        self,
-        vertices: Dict[int, NextflowDAGVertex],
-        dependent_vertices: Dict[int, List[int]],
-        dependent_edges_by_start: Dict[int, List[NextflowDAGEdge]],
-        dependent_edges_by_end: Dict[int, List[NextflowDAGEdge]],
-    ):
+    def __init__(self, dag: DAG):
         # todo(ayush): consolidate w/ snakemake
         assert metadata._nextflow_metadata is not None
@@ -129,22 +125,11 @@ def __init__(
 
         self.nextflow_tasks: List[NextflowTask] = []
 
-        self.build_from_nextflow_dag(
-            vertices,
-            dependent_vertices,
-            dependent_edges_by_start,
-            dependent_edges_by_end,
-            python_interface,
-        )
+        self.dag = dag
 
-    def build_from_nextflow_dag(
-        self,
-        vertices: Dict[int, NextflowDAGVertex],
-        dependent_vertices: Dict[int, List[int]],
-        dependent_edges_by_start: Dict[int, List[NextflowDAGEdge]],
-        dependent_edges_by_end: Dict[int, List[NextflowDAGEdge]],
-        python_interface: Interface,
-    ):
+        self.build_from_nextflow_dag()
+
+    def build_from_nextflow_dag(self):
         global_start_node = Node(
             id=_common_constants.GLOBAL_INPUT_NODE_ID,
             metadata=None,
@@ -153,10 +138,10 @@ def build_from_nextflow_dag(
             flyte_entity=None,
         )
 
-        interface_inputs = transform_variable_map(python_interface.inputs)
+        interface_inputs = transform_variable_map(self.python_interface.inputs)
 
         main_task_bindings = []
-        for k in python_interface.inputs:
+        for k in self.python_interface.inputs:
             var = interface_inputs[k]
             promise_to_bind = Promise(
                 var=k,
@@ -171,19 +156,10 @@ def build_from_nextflow_dag(
             )
         )
 
-        main_task_interface = Interface(python_interface.inputs, None, docstring=None)
-
-        main_task = NextflowMainTask(interface=main_task_interface, wf=self)
-        main_node = Node(
-            id="main",
-            metadata=main_task.construct_node_metadata(),
-            bindings=sorted(main_task_bindings, key=lambda x: x.var),
-            upstream_nodes=[],
-            flyte_entity=main_task,
-        )
-
         # wf input files that need to be downloaded into every task
-        global_wf_inputs = {f"wf_{k}": v for k, v in python_interface.inputs.items()}
+        global_wf_inputs = {
+            f"wf_{k}": v for k, v in self.python_interface.inputs.items()
+        }
         global_wf_input_bindings = [
             binding_from_python(
                 var_name=f"wf_{k}",
@@ -194,72 +170,51 @@ def build_from_nextflow_dag(
                 ),
                 t_value_type=interface_inputs[k],
             )
-            for k, v in python_interface.inputs.items()
+            for k, v in self.python_interface.inputs.items()
         ]
 
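+        # build one Flyte node per DAG vertex; iterating in topological order
+        # guarantees node_map already contains every ancestor of the current
+        # vertex when its bindings are constructed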
-        node_map: Dict[int, Node] = {}
-        extra_nodes: List[Node] = [main_node]
-        main_node_outputs: Dict[str, Type[ParameterType]] = {}
-        for vertex_id in sorted(dependent_vertices.keys()):
-            vertex = vertices[vertex_id]
-            if vertex.vertex_type == VertexType.origin:
-                continue
+        node_map: Dict[str, Node] = {}
+        extra_nodes: List[Node] = []
+        for vertex in self.dag.toposorted():
             upstream_nodes = [global_start_node]
-            bindings: List[literals_models.Binding] = []
-            for edge in dependent_edges_by_end[vertex.id]:
-                depen_vertex = vertices.get(edge.connection[0])
-                param_name = f"c{edge.from_idx}"
-
-                if (
-                    depen_vertex is None
-                    or depen_vertex.vertex_type == VertexType.origin
-                ):
-                    if vertex.id not in main_task.main_target_ids:
-                        main_task.main_target_ids.append(vertex.id)
-
-                    main_out_param_name = f"v{vertex.id}_c{edge.to_idx}"
-                    node_output = NodeOutput(node=main_node, var=main_out_param_name)
-
-                    main_node_outputs[main_out_param_name] = List[str]
+
+            task_inputs = {**global_wf_inputs}
+            task_outputs = {"default": List[str]}
+
+            task_bindings: List[literals_models.Binding] = [*global_wf_input_bindings]
+            for dep in self.dag.ancestors()[vertex]:
+                if dep.type == VertexType.Conditional:
+                    param_name = f"condition_{dep.id}"
+                    task_inputs[param_name] = bool
+
+                    node = NodeOutput(
+                        node=node_map[dep.id],
+                        var="condition",
+                    )
                 else:
-                    source_param = f"c{edge.to_idx}"
-                    node_output = NodeOutput(
-                        node=node_map[depen_vertex.id],
-                        var=source_param,
+                    param_name = f"c{dep.id}"
+                    task_inputs[param_name] = List[str]
+
+                    node = NodeOutput(
+                        node=node_map[dep.id],
+                        var="default",
                     )
 
-                promise_to_bind = Promise(
-                    var=param_name,
-                    val=node_output,
-                )
-                bindings.append(
+                task_bindings.append(
                     literals_models.Binding(
                         var=param_name,
                         binding=literals_models.BindingData(
-                            promise=promise_to_bind.ref
+                            promise=Promise(var=param_name, val=node).ref
                         ),
                     )
                 )
 
-                if depen_vertex and depen_vertex.id in node_map:
-                    upstream_nodes.append(node_map[depen_vertex.id])
+                upstream_nodes.append(node_map[dep.id])
 
-            python_inputs = {
-                f"c{e.from_idx}": List[str]
-                for e in dependent_edges_by_end.get(vertex.id, [])
-            }
-            python_inputs = {**python_inputs, **global_wf_inputs}
-            bindings = [*bindings, *global_wf_input_bindings]
-
-            python_outputs = {
-                f"c{e.to_idx}": List[str]
-                for e in dependent_edges_by_start.get(vertex.id, [])
-            }
-
-            if vertex.vertex_type == VertexType.process:
+            if vertex.type == VertexType.Process:
                 pre_adapter_task = NextflowProcessPreAdapterTask(
-                    inputs=python_inputs,
+                    inputs=task_inputs,
                     id=f"{vertex.id}_pre",
                     name=f"pre_adapter_{identifier_from_str(vertex.label)}",
                     wf=self,
@@ -269,14 +224,14 @@ def build_from_nextflow_dag(
                 pre_adapter_node = Node(
                     id=f"n{vertex.id}-pre-adapter",
                     metadata=pre_adapter_task.construct_node_metadata(),
-                    bindings=sorted(bindings, key=lambda b: b.var),
+                    bindings=sorted(task_bindings, key=lambda b: b.var),
                     upstream_nodes=upstream_nodes,
                     flyte_entity=pre_adapter_task,
                 )
                 extra_nodes.append(pre_adapter_node)
 
                 post_adapter_task = NextflowProcessPostAdapterTask(
-                    outputs=python_outputs,
+                    outputs=task_outputs,
                     id=f"{vertex.id}_post",
                     name=f"post_adapter_{identifier_from_str(vertex.label)}",
                     wf=self,
@@ -321,7 +276,7 @@ def parse_dataclass(
                     num_outputs=num_outputs,
                     id=vertex.id,
                     name=identifier_from_str(vertex.label),
-                    code=vertex.code,
+                    code=vertex.label,  # todo(ayush)
                     wf=self,
                 )
 
@@ -377,10 +332,30 @@ def parse_dataclass(
 
                 node_map[vertex.id] = post_adapter_node
 
-            elif vertex.vertex_type == VertexType.operator:
+            elif vertex.type == VertexType.Conditional:
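+                # a conditional vertex becomes a boolean-output task; its
+                # dependents consume the result through the condition_<id>
+                # inputs wired up in the ancestor loop above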
+                conditional_task = NextflowConditionalTask(
+                    task_inputs,
+                    vertex.id,
+                    f"Conditional: {vertex.label}",
+                    self,
+                )
+                self.nextflow_tasks.append(conditional_task)
+
+                node = Node(
+                    id=f"n{vertex.id}",
+                    metadata=conditional_task.construct_node_metadata(),
+                    bindings=task_bindings,
+                    upstream_nodes=upstream_nodes,
+                    flyte_entity=conditional_task,
+                )
+
+                node_map[vertex.id] = node
+
+            else:
+                # remaining vertex types (operators, generators, inputs) are
+                # all handled as generic operator tasks
                 operator_task = NextflowOperatorTask(
-                    inputs=python_inputs,
-                    outputs=python_outputs,
+                    inputs=task_inputs,
+                    outputs=task_outputs,
                     name=vertex.label,
                     id=vertex.id,
                     wf=self,
@@ -390,28 +365,20 @@ def parse_dataclass(
                 node = Node(
                     id=f"n{vertex.id}",
                     metadata=operator_task.construct_node_metadata(),
-                    bindings=bindings,
+                    bindings=task_bindings,
                     upstream_nodes=upstream_nodes,
                     flyte_entity=operator_task,
                 )
 
                 node_map[vertex.id] = node
 
-            elif vertex.vertex_type in {VertexType.origin, VertexType.node}:
-                # generic channel
-                ...
-
-            else:
-                raise ValueError(f"Unsupported vertex type for {repr(vertex)}")
 
         self._nodes = list(node_map.values()) + extra_nodes
 
-        main_task._python_outputs = main_node_outputs
-        self.main_task = main_task
-        main_node.flyte_entity._interface._outputs = transform_variable_map(
-            main_node_outputs
-        )
-
     def execute(self, **kwargs):
         return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)
@@ -421,7 +388,7 @@ def __init__(
         self,
         inputs: Dict[str, Type[ParameterType]],
         outputs: Dict[str, Type[ParameterType]],
-        id: int,
+        id: str,
         name: str,
         wf: NextflowWorkflow,
     ):
@@ -509,6 +476,8 @@ def get_k8s_pod(self, settings: SerializationSettings) -> _task_models.K8sPod:
             ),
         )
 
+    def execute(self): ...
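+    # stub: task bodies are emitted as generated Python source via
+    # get_fn_code, so execute is never invoked on this object directly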
+
 
 class NextflowTaskResolver(DefaultTaskResolver):
     @property
@@ -845,7 +814,7 @@ class NextflowProcessPreAdapterTask(NextflowTask):
     def __init__(
         self,
         inputs: Dict[str, Type[ParameterType]],
-        id: int,
+        id: str,
         name: str,
         wf: NextflowWorkflow,
     ):
@@ -940,7 +909,7 @@ class NextflowProcessPostAdapterTask(NextflowTask):
     def __init__(
         self,
         outputs: Dict[str, Type[ParameterType]],
-        id: int,
+        id: str,
         name: str,
         wf: NextflowWorkflow,
     ):
@@ -1023,7 +992,7 @@ def __init__(
         self,
         inputs: Dict[str, Type[ParameterType]],
         outputs: Dict[str, Type[ParameterType]],
-        id: int,
+        id: str,
         name: str,
         wf: NextflowWorkflow,
     ):
@@ -1169,175 +1138,27 @@ def get_fn_code(self, nf_path_in_container: str):
 
         return code_block
 
-class NextflowMainTask(NextflowTask):
-    def __init__(self, interface: Interface, wf: NextflowWorkflow):
-        self.main_target_ids = []
-
-        super().__init__(
-            name="main",
-            id=0,
-            inputs=interface.inputs,
-            outputs=interface.outputs,
-            wf=wf,
-        )
-
-    def get_fn_interface(self):
-        res = ""
-
-        outputs_str = "None:"
-        if len(self._python_outputs.items()) > 0:
-            output_fields = "\n".join(
-                reindent(
-                    rf"""
-                    {param}: {type_repr(t)}
-                    """,
-                    1,
-                ).rstrip()
-                for param, t in self._python_outputs.items()
-            )
-
-            res += reindent(
-                rf"""
-                class Res{self.name}(NamedTuple):
-                __output_fields__
-
-                """,
-                0,
-            ).replace("__output_fields__", output_fields)
-            outputs_str = f"Res{self.name}:"
-
-        params_str = ",\n".join(
-            reindent(
-                rf"""
-                {param}: {type_repr(t)}
-                """,
-                1,
-            ).rstrip()
-            for param, t in self._python_inputs.items()
-        )
-
-        res += (
-            reindent(
-                rf"""
-                task = custom_task(cpu=-1, memory=-1) # these limits are a lie and are ignored when generating the task spec
-                @task(cache=True)
-                def {self.name}(
-                __params__
-                ) -> __outputs__
-                """,
-                0,
-            )
-            .replace("__params__", params_str)
-            .replace("__outputs__", outputs_str)
-        )
-        return res
-
-    def get_fn_return_stmt(self):
-        results: List[str] = []
-        for out_name, out_type in self._python_outputs.items():
-            results.append(
-                reindent(
-                    rf"""
-                    {out_name}=out_channels.get("{out_name}", [])
-                    """,
-                    2,
-                ).rstrip()
-            )
-
-        return_str = ",\n".join(results)
-
-        return reindent(
-            rf"""
-            return Res{self.name}(
-            __return_str__
-            )
-            """,
-            0,
-        ).replace("__return_str__", return_str)
-
-    def get_fn_code(self, nf_path_in_container: str):
-        code_block = self.get_fn_interface()
-
-        run_task_entrypoint = [
-            "/root/nextflow",
-            "run",
-            nf_path_in_container,
-            "-profile",
-            "mamba",
-            "-latchTarget",
-        ]
-
-        for flag, val in self.wf.flags_to_params.items():
-            run_task_entrypoint.extend([flag, str(val)])
-
-        for k, v in self.wf.downloadable_params.items():
-            code_block += reindent(
-                f"""
-                {k}_p = Path({k}).resolve()
-                {k}_dest_p = Path({repr(v)}).resolve()
-
-                check_exists_and_rename(
-                    {k}_p,
-                    {k}_dest_p
-                )
-                """,
-                1,
-            )
-
-        code_block += reindent(
-            rf"""
-            print(f"\n\n\nRunning nextflow task: {run_task_entrypoint}\n")
-            try:
-                subprocess.run(
-                    [{','.join([repr(x) for x in run_task_entrypoint])}],
-                    env={{
-                        **os.environ,
-                        "LATCH_MAIN_TARGET_IDS": "{json.dumps(self.main_target_ids)}",
-                    }},
-                    check=True,
-                )
-            except Exception as e:
-                print("\n\n\n[!] Failed\n\n\n")
-                raise e
-
-            out_channels = {{}}
-            files = list(glob.glob(".latch_compiled_channels/*/channel*.txt"))
-            for file in files:
-                idx = parse_channel_file(file)
-                vals = Path(file).read_text().strip().split("\n")
-                v_id = Path(file).parent.name
-                out_channels[f"v{{v_id}}_c{{idx}}"] = vals
-
-            import json
-            print(json.dumps(out_channels, indent=2))
-
-            """,
-            1,
-        )
-
-        code_block += self.get_fn_return_stmt()
-        return code_block
+class NextflowConditionalTask(NextflowOperatorTask):
+    def __init__(
+        self,
+        inputs: Dict[str, Type[ParameterType]],
+        id: str,
+        name: str,
+        wf: NextflowWorkflow,
+    ):
+        self.operator_id = id
 
+        super().__init__(inputs, {"condition": bool}, id, name, wf)
 
-def build_nf_wf(pkg_root: Path, nf_script: Path):
+
+def build_nf_wf(pkg_root: Path, nf_script: Path) -> NextflowWorkflow:
     try:
         subprocess.run(
             [
                 str(pkg_root / ".latch/bin/nextflow"),
                 "run",
                 str(nf_script),
-                "-with-dag",
-                "-latchJIT",
-                "--input",
-                str(pkg_root / "assets" / "samplesheet.csv"),
-                "--outdir",
-                str(pkg_root),
-                # "--run_amp_screening",
-                # "--amp_skip_hmmsearch",
-                # "--run_arg_screening",
-                # "--run_bgc_screening",
-                # "--bgc_skip_hmmsearch",
+                "-latchRegister",
             ],
             check=True,
         )
@@ -1345,92 +1166,44 @@ def build_nf_wf(pkg_root: Path, nf_script: Path):
         print("\n\n\n[!] Failed\n\n\n")
         raise e
 
-    with open(pkg_root / ".latch/nextflowDAG.json") as f:
-        dag = json.load(f)
+    dags: Dict[str, DAG] = {}
 
-    vertices_json = dag["vertices"]
-    edges_json = dag["edges"]
+    dag_files = map(Path, glob.glob(str(pkg_root / ".latch" / "*.dag.json")))
+    for dag_file in dag_files:
+        # strip the ".dag.json" suffix to recover the workflow name
+        wf_name = dag_file.name.rsplit(".", 2)[0]
 
-    vertices: Dict[int, NextflowDAGVertex] = {}
-    dependent_vertices: Dict[int, List[int]] = {}
-    for v in vertices_json:
-        content = v["content"]
+        dags[wf_name] = DAG.from_path(dag_file)
 
-        code: Optional[str] = None
-        if "source" in content:
-            code = content["source"]
+    resolved = DAG.resolve_subworkflows(dags)
 
-        input_params: List[NextflowParam] = []
-        if content["inputParams"] is not None:
-            for x in content["inputParams"]:
-                t = NextflowInputParamType(x["type"])
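+    # resolve_subworkflows inlines nested workflow DAGs; what remains is one
+    # entry per independent top-level workflow, keyed by name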
+    if len(resolved) == 0:
+        click.secho(
+            "No Nextflow workflows found in this project. Aborting.", fg="red"
+        )
 
-            input_params.append(
-                NextflowParam(name=format_param_name(x["name"], t), type=t)
-            )
+        raise click.exceptions.Exit(1)
 
-    output_params: List[NextflowParam] = []
-    if content["outputParams"] is not None:
-        for x in content["outputParams"]:
-            t = NextflowOutputParamType(x["type"])
+    dag = list(resolved.values())[0]
 
-            output_params.append(
-                NextflowParam(name=format_param_name(x["name"], t), type=t)
-            )
-
-        vertex = NextflowDAGVertex(
-            id=content["id"],
-            label=content["label"],
-            vertex_type=VertexType(content["type"].lower()),
-            input_params=input_params,
-            output_params=output_params,
-            code=code,
-        )
-
-        vertices[vertex.id] = vertex
-        dependent_vertices[vertex.id] = []
-
-    dependent_edges_by_start: Dict[int, List[NextflowDAGEdge]] = {}
-    dependent_edges_by_end: Dict[int, List[NextflowDAGEdge]] = {}
-    for i in vertices.keys():
-        dependent_edges_by_start[i] = []
-        dependent_edges_by_end[i] = []
-
-    edges = []
-    for edge_json in edges_json:
-        edge_content = edge_json["content"]
-
-        edge = NextflowDAGEdge(
-            id=edge_content["id"],
-            to_idx=edge_content["outIdx"],
-            from_idx=edge_content["inIdx"],
-            label=edge_content["label"],
-            connection=edge_content["connection"],
+    if len(resolved) > 1:
+        dag = select_tui(
+            "We found multiple independent workflows in this Nextflow project. Which"
+            " would you like to register?",
+            [
+                {
+                    "display_name": (
+                        k + " (Anonymous Workflow)" if k == "mainWorkflow" else k
+                    ),
+                    "value": v,
+                }
+                for k, v in resolved.items()
+            ],
         )
-        edges.append(edge)
-
-        if edge.connection[0] is not None:
-            dependent_edges_by_start[edge.connection[0]].append(edge)
 
-        if edge.connection[1] is not None:
-            dependent_edges_by_end[edge.connection[1]].append(edge)
+        if dag is None:
+            click.echo("No workflow selected. Aborting.")
 
-        from_vertex, to_vertex = edge.connection
-        if to_vertex is not None:
-            dependent_vertices[to_vertex].append(from_vertex)
-
-    print("Vertices:")
-    for v in vertices.values():
-        print(f"  {v}")
-
-    print("Edges:")
-    for e in edges:
-        print(f"  {e}")
-
-    return NextflowWorkflow(
-        vertices, dependent_vertices, dependent_edges_by_start, dependent_edges_by_end
-    )
+            raise click.exceptions.Exit(0)
+
+    return NextflowWorkflow(dag)
 
 
 def nf_path_in_container(nf_script: Path, pkg_root: Path) -> str:
@@ -1485,9 +1258,6 @@ def parse_channel_file(fname: str) -> int:
 
     """).lstrip()
 
-    entrypoint_code_block += wf.main_task.get_fn_code(
-        nf_path_in_container(nf_path, pkg_root)
-    )
 
     for task in wf.nextflow_tasks:
         entrypoint_code_block += (
diff --git a/latch_cli/services/register/register.py b/latch_cli/services/register/register.py
index 1f585fd4..1018bfd6 100644
--- a/latch_cli/services/register/register.py
+++ b/latch_cli/services/register/register.py
@@ -248,8 +248,6 @@ def _build_and_serialize(
         from ...extras.nextflow.serialize import serialize_nf
 
         serialize_nf(nf_wf, tmp_dir, image_name, ctx.dkr_repo)
-
-        sys.exit(1)
     else:
         serialize_logs, container_id = serialize_pkg_in_container(
             ctx, image_name, tmp_dir