From 17151ca14f25854197449dcdd2a40aaf07f471d6 Mon Sep 17 00:00:00 2001
From: Alisson Gusatti Azzolini
Date: Mon, 6 Feb 2017 17:28:19 -0800
Subject: [PATCH] Debug/Analysis tools for Jobs/ExecutionSteps

Summary:
Introduces two utilities:
- ##print_obj##: Prints the whole Job in a readable way -- each op call
  takes a single line and nets are inlined for much better readability.
  Loops and parallel steps are easy to read.
- ##analyse_obj##: Goes through a Job and checks two things:
  - that no undefined-blob errors will occur at execution time;
  - that no blob of the same name is created by parallel execution steps.

Reviewed By: dzhulgakov

Differential Revision: D4142381

fbshipit-source-id: 61bf3398c22e9947493e99145ce2bfc2646830a6
---
 caffe2/python/net_printer.py      | 289 ++++++++++++++++++++++++++++++
 caffe2/python/net_printer_test.py |  89 +++++++++
 2 files changed, 378 insertions(+)
 create mode 100644 caffe2/python/net_printer.py
 create mode 100644 caffe2/python/net_printer_test.py

diff --git a/caffe2/python/net_printer.py b/caffe2/python/net_printer.py
new file mode 100644
index 0000000000000..d88dce547a338
--- /dev/null
+++ b/caffe2/python/net_printer.py
@@ -0,0 +1,289 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.proto.caffe2_pb2 import OperatorDef
+from caffe2.python.checkpoint import Job
+from caffe2.python.core import Net, ExecutionStep, Plan
+from caffe2.python.task import Task, TaskGroup, WorkspaceType
+from collections import defaultdict
+from contextlib import contextmanager
+from copy import copy
+
+
+class Visitor(object):
+    @classmethod
+    def register(cls, Type):
+        if not hasattr(cls, 'visitors'):
+            cls.visitors = []
+
+        def _register(func):
+            cls.visitors.append((Type, func))
+            return func
+
+        return _register
+
+    def __call__(self, obj, *args, **kwargs):
+        if obj is None:
+            return
+        for Type, func in self.__class__.visitors:
+            if isinstance(obj, Type):
+                return func(self, obj, *args, **kwargs)
+        raise TypeError('%s: unsupported object type: %s' % (
+            self.__class__.__name__, type(obj)))
+
+
+class Analyzer(Visitor):
+    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}
+
+    def __init__(self):
+        self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
+        self.workspace_ctx = []
+
+    @property
+    def workspace(self):
+        return self.workspace_ctx[-1]
+
+    @contextmanager
+    def set_workspace(self, node=None, ws=None, do_copy=False):
+        if ws is None:
+            ws = (self.workspaces[str(node)]
+                  if node is not None else self.workspace)
+        if do_copy:
+            ws = copy(ws)
+        self.workspace_ctx.append(ws)
+        # pop in a finally block so an exception in the body does not
+        # leave a stale workspace on the stack
+        try:
+            yield ws
+        finally:
+            del self.workspace_ctx[-1]
+
+    def define_blob(self, blob):
+        self.workspace[blob] += 1
+
+    def need_blob(self, blob):
+        if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
+            return
+        assert blob in self.workspace, 'Blob undefined: %s' % blob
+
+
+@Analyzer.register(OperatorDef)
+def analyze_op(analyzer, op):
+    # explicit loops instead of map(): map() is lazy under Python 3,
+    # so the side effects would never run there
+    for blob in op.input:
+        analyzer.need_blob(blob)
+    for blob in op.output:
+        analyzer.define_blob(blob)
+
+
+@Analyzer.register(Net)
+def analyze_net(analyzer, net):
+    for op in net.Proto().op:
+        analyzer(op)
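+
+
+# Illustrative sketch, not part of this diff's API surface: new handlers are
+# attached to a Visitor subclass the same way as the functions above. The
+# names `OpCounter`, `count_ops` and `my_net` below are hypothetical.
+#
+#   class OpCounter(Visitor):
+#       def __init__(self):
+#           self.count = 0
+#
+#   @OpCounter.register(Net)
+#   def count_ops(counter, net):
+#       # accumulate the number of ops seen across visited nets
+#       counter.count += len(net.Proto().op)
+#       return counter.count
+#
+#   counter = OpCounter()
+#   counter(my_net)  # dispatches to count_ops via the isinstance() check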
+
+
+@Analyzer.register(ExecutionStep)
+def analyze_step(analyzer, step):
+    proto = step.Proto()
+    if proto.report_net:
+        with analyzer.set_workspace(do_copy=True):
+            analyzer(step.get_net(proto.report_net))
+    all_new_blobs = set()
+    substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
+    for substep in substeps:
+        with analyzer.set_workspace(
+                do_copy=proto.concurrent_substeps) as ws_in:
+            analyzer(substep)
+            if proto.should_stop_blob:
+                analyzer.need_blob(proto.should_stop_blob)
+        if proto.concurrent_substeps:
+            new_blobs = set(ws_in.keys()) - set(analyzer.workspace.keys())
+            assert len(all_new_blobs & new_blobs) == 0, (
+                'Error: Blobs created by multiple parallel steps: %s' % (
+                    ', '.join(all_new_blobs & new_blobs)))
+            all_new_blobs |= new_blobs
+    for blob in all_new_blobs:
+        analyzer.define_blob(blob)
+
+
+@Analyzer.register(Task)
+def analyze_task(analyzer, task):
+    # check that our plan protobuf is not too large (limit of 64Mb)
+    step = task.get_step()
+    plan = Plan(task.node)
+    plan.AddStep(step)
+    proto_len = len(plan.Proto().SerializeToString())
+    assert proto_len < 2 ** 26, (
+        'Due to a protobuf limitation, serialized tasks must be smaller '
+        'than 64Mb, but this task has %d bytes.' % proto_len)
+
+    is_private = task.workspace_type() != WorkspaceType.GLOBAL
+    with analyzer.set_workspace(do_copy=is_private):
+        analyzer(step)
+
+
+@Analyzer.register(TaskGroup)
+def analyze_task_group(analyzer, tg):
+    for task in tg.tasks_by_node().tasks():
+        with analyzer.set_workspace(node=task.node):
+            analyzer(task)
+
+
+@Analyzer.register(Job)
+def analyze_job(analyzer, job):
+    analyzer(job.init_group)
+    analyzer(job.epoch_group)
+
+
+def analyze(obj):
+    """
+    Given a Job, visits all the execution steps, making sure that:
+    - no undefined blobs will be found during execution;
+    - no blob of the same name is defined in concurrent steps.
+    """
+    Analyzer()(obj)
+
+
+class Text(object):
+    def __init__(self):
+        self._indent = 0
+        self._lines_in_context = [0]
+        self.lines = []
+
+    @contextmanager
+    def context(self, text):
+        if text is not None:
+            self.add('with %s:' % text)
+            self._indent += 4
+            self._lines_in_context.append(0)
+        yield
+        if text is not None:
+            # emit 'pass' at the inner indent level *before* dedenting,
+            # so that empty blocks still print as valid pseudo-Python
+            if self._lines_in_context[-1] == 0:
+                self.add('pass')
+            self._indent -= 4
+            del self._lines_in_context[-1]
+
+    def add(self, text):
+        self._lines_in_context[-1] += 1
+        self.lines.append((' ' * self._indent) + text)
+
+    def __str__(self):
+        return '\n'.join(self.lines)
+
+
+class Printer(Visitor, Text):
+    pass
+
+
+def _sanitize_str(s):
+    s = str(s)
+    return s if len(s) < 64 else (s[:64] + '...<+len=%d>' % (len(s) - 64))
+
+
+def _arg_val(arg):
+    if arg.HasField('f'):
+        return str(arg.f)
+    if arg.HasField('i'):
+        return str(arg.i)
+    if arg.HasField('s'):
+        return _sanitize_str(arg.s)
+    if arg.floats:
+        return str(list(arg.floats))
+    if arg.ints:
+        return str(list(arg.ints))
+    if arg.strings:
+        return str([_sanitize_str(s) for s in arg.strings])
+    return '[]'
+
+
+def call(op, inputs=None, outputs=None):
+    inputs = '' if not inputs else ', '.join(
+        '%s=%s' % (str(a[0]), str(a[1])) if isinstance(a, tuple) else str(a)
+        for a in inputs)
+    call = '%s(%s)' % (op, inputs)
+    return call if not outputs else '%s = %s' % (', '.join(outputs), call)
+
+
+@Printer.register(OperatorDef)
+def print_op(text, op):
+    text.add(call(
+        op.type,
+        list(op.input) + [(a.name, _arg_val(a)) for a in op.arg],
+        op.output))
+
+
+@Printer.register(Net)
+def print_net(text, net):
+    text.add('# net: %s' % str(net))
+    for op in net.Proto().op:
+        text(op)
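+
+
+# For orientation, the printer's output looks roughly like the sketch below.
+# This is illustrative only; the exact op names, blob names and argument
+# values depend on the Job being printed:
+#
+#   with Task(node=local):
+#       # net: task_init
+#       one = ConstantFill(value=1.0)
+#       with loop(10):
+#           # net: loop_body
+#           total = Add(total, one)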
+
+
+def _get_step_context(step):
+    proto = step.Proto()
+    if proto.should_stop_blob:
+        return call('loop'), None
+    if proto.num_iter and proto.num_iter != 1:
+        return call('loop', [proto.num_iter]), None
+    concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1
+    if concurrent:
+        return call('parallel'), call('step')
+    if proto.report_net:
+        return call('run_once'), None
+    return None, None
+
+
+@Printer.register(ExecutionStep)
+def print_step(text, step):
+    proto = step.Proto()
+    step_ctx, substep_ctx = _get_step_context(step)
+    with text.context(step_ctx):
+        if proto.report_net:
+            with text.context(call('report_net', [proto.report_interval])):
+                text(step.get_net(proto.report_net))
+        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
+        for substep in substeps:
+            with text.context(substep_ctx):
+                text(substep)
+        if proto.should_stop_blob:
+            text.add(call('yield stop_if', [proto.should_stop_blob]))
+
+
+@Printer.register(Task)
+def print_task(text, task):
+    with text.context(call('Task', [('node', task.node)])):
+        text(task.get_step())
+
+
+@Printer.register(TaskGroup)
+def print_task_group(text, tg, header=None):
+    with text.context(header or call('TaskGroup')):
+        for task in tg.tasks_by_node().tasks():
+            text(task)
+
+
+@Printer.register(Job)
+def print_job(text, job):
+    text(job.init_group, 'Job.current().init_group')
+    text(job.epoch_group, 'Job.current().epoch_group')
+
+
+def to_string(obj):
+    """
+    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
+    with a detailed description of the execution steps.
+    """
+    printer = Printer()
+    printer(obj)
+    return str(printer)
+
+
+def debug_net(net):
+    """
+    Given a Net, produce another net that logs info about the operator call
+    before each operator execution. Use for debugging purposes.
+    """
+    assert isinstance(net, Net)
+    debug_net = Net(str(net))
+    for op in net.Proto().op:
+        text = Text()
+        # print_op takes the Text first, then the op
+        print_op(text, op)
+        debug_net.LogInfo(str(text))
+        debug_net.Proto().op.extend([op])
+    return debug_net
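+
+
+# Illustrative usage sketch for debug_net; the Net `my_net` and the run call
+# below are assumptions for the example, not part of this diff:
+#
+#   from caffe2.python import workspace
+#   noisy = debug_net(my_net)
+#   workspace.RunNetOnce(noisy)  # logs each op call before executing it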
diff --git a/caffe2/python/net_printer_test.py b/caffe2/python/net_printer_test.py
new file mode 100644
index 0000000000000..2d6f5a172326c
--- /dev/null
+++ b/caffe2/python/net_printer_test.py
@@ -0,0 +1,89 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import net_printer
+from caffe2.python.checkpoint import Job
+from caffe2.python.net_builder import ops
+from caffe2.python.task import Task, final_output
+import unittest
+
+
+def example_loop():
+    with Task():
+        total = ops.Const(0)
+        total_large = ops.Const(0)
+        total_small = ops.Const(0)
+        total_tiny = ops.Const(0)
+        with ops.loop(10) as loop:
+            outer = ops.Mul([loop.iter(), ops.Const(10)])
+            with ops.loop(loop.iter()) as inner:
+                val = ops.Add([outer, inner.iter()])
+                with ops.If(ops.GE([val, ops.Const(80)])) as c:
+                    ops.Add([total_large, val], [total_large])
+                with c.Elif(ops.GE([val, ops.Const(50)])) as c:
+                    ops.Add([total_small, val], [total_small])
+                with c.Else():
+                    ops.Add([total_tiny, val], [total_tiny])
+                ops.Add([total, val], total)
+
+
+def example_task():
+    with Task():
+        with ops.task_init():
+            one = ops.Const(1)
+        two = ops.Add([one, one])
+        with ops.task_init():
+            three = ops.Const(3)
+        accum = ops.Add([two, three])
+        # here, accum should be 5
+        with ops.task_exit():
+            # here, accum should be 6, since this executes after lines below
+            seven_1 = ops.Add([accum, one])
+        six = ops.Add([accum, one])
+        ops.Add([accum, one], [accum])
+        seven_2 = ops.Add([accum, one])
+        o6 = final_output(six)
+        o7_1 = final_output(seven_1)
+        o7_2 = final_output(seven_2)
+    return o6, o7_1, o7_2
+
+
+def example_job():
+    with Job() as job:
+        with job.init_group:
+            example_loop()
+        example_task()
+    return job
+
+
+class TestNetPrinter(unittest.TestCase):
+    def test_print(self):
+        self.assertTrue(len(net_printer.to_string(example_job())) > 0)
+
+    def test_valid_job(self):
+        job = example_job()
+        with job:
+            with Task():
+                # distributed_ctx_init_* blobs are ignored by the analyzer
+                ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b'])
+        # analyze the job we just extended, not a fresh copy of it
+        net_printer.analyze(job)
+
+    def test_undefined_blob(self):
+        job = example_job()
+        with job:
+            with Task():
+                ops.Add(['a', 'b'])
+        with self.assertRaises(AssertionError):
+            net_printer.analyze(job)
+
+    def test_multiple_definition(self):
+        job = example_job()
+        with job:
+            with Task():
+                ops.Add([ops.Const(0), ops.Const(1)], 'out1')
+            with Task():
+                ops.Add([ops.Const(2), ops.Const(3)], 'out1')
+        with self.assertRaises(AssertionError):
+            net_printer.analyze(job)
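+
+
+# Not part of the test suite -- a minimal sketch of how the two utilities
+# introduced by this diff are meant to be used from an interactive session:
+#
+#   job = example_job()
+#   print(net_printer.to_string(job))  # one line per op call, nets inlined
+#   net_printer.analyze(job)           # raises AssertionError on problems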