Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dvc: implement multistage dvcfile #3676

Merged
merged 14 commits into from
Apr 28, 2020
28 changes: 9 additions & 19 deletions dvc/command/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,15 @@
logger = logging.getLogger(__name__)


def _stage_repr(stage):
from dvc.stage import PipelineStage

return (
"{}:{}".format(stage.relpath, stage.name)
if isinstance(stage, PipelineStage)
else stage.relpath
)


class CmdPipelineShow(CmdBase):
def _show(self, target, commands, outs, locked):
import networkx
from dvc import dvcfile
from dvc.utils import parse_target

path, name = parse_target(target)
stage = dvcfile.Dvcfile(self.repo, path).stages[name]
G = self.repo.pipeline_graph
path, name, tag = parse_target(target)
stage = dvcfile.Dvcfile(self.repo, path, tag=tag).stages[name]
G = self.repo.graph
stages = networkx.dfs_postorder_nodes(G, stage)
if locked:
stages = [s for s in stages if s.locked]
Expand All @@ -40,16 +30,16 @@ def _show(self, target, commands, outs, locked):
for out in stage.outs:
logger.info(str(out))
else:
logger.info(_stage_repr(stage))
logger.info(stage.addressing)

def _build_graph(self, target, commands=False, outs=False):
import networkx
from dvc import dvcfile
from dvc.repo.graph import get_pipeline
from dvc.utils import parse_target

path, name = parse_target(target)
target_stage = dvcfile.Dvcfile(self.repo, path).stages[name]
path, name, tag = parse_target(target)
target_stage = dvcfile.Dvcfile(self.repo, path, tag=tag).stages[name]
G = get_pipeline(self.repo.pipelines, target_stage)

nodes = set()
Expand All @@ -62,7 +52,7 @@ def _build_graph(self, target, commands=False, outs=False):
for out in stage.outs:
nodes.add(str(out))
else:
nodes.add(_stage_repr(stage))
nodes.add(stage.addressing)

edges = []
for from_stage, to_stage in networkx.edge_dfs(G, target_stage):
Expand All @@ -75,7 +65,7 @@ def _build_graph(self, target, commands=False, outs=False):
for to_out in to_stage.outs:
edges.append((str(from_out), str(to_out)))
else:
edges.append((_stage_repr(from_stage), _stage_repr(to_stage)))
edges.append((from_stage.addressing, to_stage.addressing))

return list(nodes), edges, networkx.is_tree(G)

Expand Down Expand Up @@ -163,7 +153,7 @@ def run(self):
pipelines = self.repo.pipelines
for pipeline in pipelines:
for stage in pipeline:
logger.info(_stage_repr(stage))
logger.info(stage.addressing)
if len(pipeline) != 0:
logger.info("=" * 80)
logger.info("{} pipelines total".format(len(pipelines)))
Expand Down
Loading