From d014f5dc5fecaab1c3af6d0e54dcfa1f60863edc Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 6 May 2024 09:59:18 -0400 Subject: [PATCH] Compute source maps when pretty-printing jaxprs. This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose. This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing. The change also teaches the core jaxpr pretty printer to populate source map information on each equation. --- jax/_src/core.py | 104 +++++++++++++++++++---------------- jax/_src/pretty_printer.py | 92 ++++++++++++++++++++++++++----- tests/BUILD | 9 +++ tests/pretty_printer_test.py | 36 ++++++++++++ 4 files changed, 181 insertions(+), 60 deletions(-) create mode 100644 tests/pretty_printer_test.py diff --git a/jax/_src/core.py b/jax/_src/core.py index e845bbcaa4ff..e34ddb8cf3e6 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -149,51 +149,11 @@ def __str__(self): def pretty_print(self, *, source_info=False, print_shapes=True, custom_pp_eqn_rules=True, name_stack=False, print_effects: bool = False, **kwargs): - context = JaxprPpContext() - settings = JaxprPpSettings( - source_info=source_info, - print_shapes=print_shapes, - custom_pp_eqn_rules=custom_pp_eqn_rules, - name_stack=name_stack, - print_effects=print_effects) - - # Compute how many times each jaxpr is used. - names = defaultdict[Jaxpr, str](lambda: "jaxpr") - jaxpr_counts = Counter[Jaxpr]() - s = deque([self]) - while s: - jaxpr = s.popleft() - jaxpr_counts[jaxpr] += 1 - for eqn in jaxpr.eqns: - # TODO(slebedev): Come up with a more elaborate heuristic for name=. - name = eqn.params.get("name") - if name is None: - s.extend(jaxprs_in_params(eqn.params)) - continue - name = name.strip("<>") # -> lambda - for subjaxpr in jaxprs_in_params(eqn.params): - s.append(subjaxpr) - names.setdefault(subjaxpr, name) - - # Pull jaxprs occurring more than once to the top-level, making sure - # that their names are unique. - docs = [] - name_counts = Counter[str]() - for jaxpr, c in jaxpr_counts.items(): - if c == 1: - continue - name = names[jaxpr] - if (count := name_counts[name]) > 0: - name_counts[name] += 1 - name += str(count) - name_counts[name] += 1 - else: - name_counts[name] += 1 - docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings)) - context.used_names.add(name) - context.top_level_jaxprs[jaxpr] = name - docs.append(pp_jaxpr(self, context, settings)) - return pp.concat(docs).format(**kwargs) + doc = pp_toplevel_jaxpr( + self, source_info=source_info, print_shapes=print_shapes, + custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack, + print_effects=print_effects) + return doc.format(**kwargs) def _repr_pretty_(self, p, cycle): return p.text(self.pretty_print(use_color=True)) @@ -212,6 +172,7 @@ def replace(self, **kwargs): return jaxpr + def join_effects(*effects: Effects) -> Effects: return set().union(*effects) if effects else no_effects @@ -3164,6 +3125,55 @@ def _check_map(ctx_factory, prim, in_avals, params): # ------------------- Jaxpr printed representation ------------------- +def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True, + custom_pp_eqn_rules=True, name_stack=False, + print_effects: bool = False) -> pp.Doc: + context = JaxprPpContext() + settings = JaxprPpSettings( + source_info=source_info, + print_shapes=print_shapes, + custom_pp_eqn_rules=custom_pp_eqn_rules, + name_stack=name_stack, + print_effects=print_effects) + + # Compute how many times each jaxpr is used. + names = defaultdict[Jaxpr, str](lambda: "jaxpr") + jaxpr_counts = Counter[Jaxpr]() + s = deque([jaxpr_to_print]) + while s: + jaxpr = s.popleft() + jaxpr_counts[jaxpr] += 1 + for eqn in jaxpr.eqns: + # TODO(slebedev): Come up with a more elaborate heuristic for name=. + name = eqn.params.get("name") + if name is None: + s.extend(jaxprs_in_params(eqn.params)) + continue + name = name.strip("<>") # -> lambda + for subjaxpr in jaxprs_in_params(eqn.params): + s.append(subjaxpr) + names.setdefault(subjaxpr, name) + + # Pull jaxprs occurring more than once to the top-level, making sure + # that their names are unique. + docs = [] + name_counts = Counter[str]() + for jaxpr, c in jaxpr_counts.items(): + if c == 1: + continue + name = names[jaxpr] + if (count := name_counts[name]) > 0: + name_counts[name] += 1 + name += str(count) + name_counts[name] += 1 + else: + name_counts[name] += 1 + docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings)) + context.used_names.add(name) + context.top_level_jaxprs[jaxpr] = name + docs.append(pp_jaxpr(jaxpr_to_print, context, settings)) + return pp.concat(docs) + class JaxprPpSettings(NamedTuple): print_shapes: bool = True @@ -3253,7 +3263,9 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings ) -> pp.Doc: rule = (_pp_eqn if not settings.custom_pp_eqn_rules else pp_eqn_rules.get(eqn.primitive, _pp_eqn)) - return rule(eqn, context, settings) # type: ignore[operator] + doc = rule(eqn, context, settings) # type: ignore[operator] + user_frame = source_info_util.user_frame(eqn.source_info) + return doc if user_frame is None else pp.source_map(doc, user_frame) def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc: annotation = (source_info_util.summarize(eqn.source_info) diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index ec5c34ab846a..0614bb8a8d9b 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -31,7 +31,7 @@ import enum from functools import partial import sys -from typing import NamedTuple +from typing import Any, NamedTuple from jax._src import config from jax._src import util @@ -69,12 +69,23 @@ def _can_use_color() -> bool: class Doc(util.StrictABC): __slots__ = () - def format(self, width: int = 80, use_color: bool | None = None, - annotation_prefix=" # ") -> str: + def format( + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None + ) -> str: + """ + Formats a pretty-printer document as a string. + + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. + """ if use_color is None: use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value return _format(self, width, use_color=use_color, - annotation_prefix=annotation_prefix) + annotation_prefix=annotation_prefix, source_map=source_map) def __str__(self): return self.format() @@ -147,6 +158,21 @@ def __init__(self, n: int, child: Doc): def __repr__(self): return f"nest({self.n, self.child})" +_NO_SOURCE = object() + +class _SourceMapDoc(Doc): + __slots__ = ("child", "source") + child: Doc + source: Any + + def __init__(self, child: Doc, source: Any): + assert isinstance(child, Doc), child + self.child = child + self.source = source + + def __repr__(self): return f"source({self.child}, {self.source})" + + Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", "MAGENTA", "CYAN", "WHITE", "RESET"]) Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"]) @@ -193,7 +219,7 @@ def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] agenda.append((i + doc.n, m, doc.child)) elif isinstance(doc, _GroupDoc): agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc): + elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): agenda.append((i, m, doc.child)) else: raise ValueError("Invalid document ", doc) @@ -224,7 +250,7 @@ def _sparse(doc: Doc) -> bool: agenda.append(doc.child) elif isinstance(doc, _GroupDoc): agenda.append(doc.child) - elif isinstance(doc, _ColorDoc): + elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): agenda.append(doc.child) else: raise ValueError("Invalid document ", doc) @@ -241,6 +267,7 @@ class _State(NamedTuple): mode: _BreakMode doc: Doc color: _ColorState + source_map: Any class _Line(NamedTuple): text: str @@ -283,17 +310,29 @@ def _align_annotations(lines): -def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: +def _format( + doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, + source_map: list[list[tuple[int, int, Any]]] | None +) -> str: lines = [] default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) color_state = default_colors - agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)] + source_start = 0 # The column at which the current source region starts. + source = _NO_SOURCE # The currently active source region. + line_source_map = [] # Source maps for the current line of text. + agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] k = 0 line_text = "" line_annotations = [] while len(agenda) > 0: - i, m, doc, color = agenda.pop() + i, m, doc, color, agenda_source = agenda.pop() + if source_map is not None and agenda_source != source: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source = agenda_source + source_start = pos if isinstance(doc, _NilDoc): pass elif isinstance(doc, _TextDoc): @@ -304,7 +343,7 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: line_annotations.append(doc.annotation) k += len(doc.text) elif isinstance(doc, _ConcatDoc): - agenda.extend(_State(i, m, d, color) + agenda.extend(_State(i, m, d, color, source) for d in reversed(doc.children)) elif isinstance(doc, _BreakDoc): if m == _BreakMode.BREAK: @@ -313,6 +352,13 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: annotation_colors) line_text += color_str lines.append(_Line(line_text, k, line_annotations)) + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) + line_source_map = [] + source_start = i line_text = " " * i line_annotations = [] k = i @@ -322,20 +368,22 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: line_text += doc.text k += len(doc.text) elif isinstance(doc, _NestDoc): - agenda.append(_State(i + doc.n, m, doc.child, color)) + agenda.append(_State(i + doc.n, m, doc.child, color, source)) elif isinstance(doc, _GroupDoc): # In Lindig's paper, _fits is passed the remainder of the document. # I'm pretty sure that's a bug and we care only if the current group fits! if (_sparse(doc) and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): - agenda.append(_State(i, _BreakMode.FLAT, doc.child, color)) + agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) else: - agenda.append(_State(i, _BreakMode.BREAK, doc.child, color)) + agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) elif isinstance(doc, _ColorDoc): color = _ColorState(doc.foreground or color.foreground, doc.background or color.background, doc.intensity or color.intensity) - agenda.append(_State(i, m, doc.child, color)) + agenda.append(_State(i, m, doc.child, color, source)) + elif isinstance(doc, _SourceMapDoc): + agenda.append(_State(i, m, doc.child, color, doc.source)) else: raise ValueError("Invalid document ", doc) @@ -343,6 +391,11 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str: color_state, color_str = _update_color(use_color, color_state, annotation_colors) line_text += color_str + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) lines.append(_Line(line_text, k, line_annotations)) lines = _align_annotations(lines) out = "\n".join( @@ -406,6 +459,17 @@ def color(doc: Doc, *, foreground: Color | None = None, intensity=intensity) +def source_map(doc: Doc, source: Any): + """Source mapping. + + A source map associates a region of the pretty-printer's text output with a + source location that produced it. For the purposes of the pretty printer a + ``source`` may be any object: we require only that we can compare sources for + equality. A text region to source object mapping can be populated as a side + output of the ``format`` method. + """ + return _SourceMapDoc(doc, source) + type_annotation = partial(color, intensity=Intensity.NORMAL, foreground=Color.MAGENTA) keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE) diff --git a/tests/BUILD b/tests/BUILD index f0b8f7570ef3..0dfeda8a2269 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1477,6 +1477,15 @@ jax_test( ], ) +py_test( + name = "pretty_printer_test", + srcs = ["pretty_printer_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + exports_files( [ "api_test.py", diff --git a/tests/pretty_printer_test.py b/tests/pretty_printer_test.py new file mode 100644 index 000000000000..d87708c9d91c --- /dev/null +++ b/tests/pretty_printer_test.py @@ -0,0 +1,36 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +from jax._src import test_util as jtu +from jax._src import pretty_printer as pp + + +class PrettyPrinterTest(jtu.JaxTestCase): + + def testSourceMap(self): + doc = pp.concat([ + pp.text("abc"), pp.source_map(pp.text("def"), 101), + pp.source_map(pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77), + pp.text("mn"), + ]) + source_map = [] + out = doc.format(width=8, source_map=source_map) + self.assertEqual(out, "abcdefgh\nijklmn") + self.assertEqual(source_map, [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]]) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())