Skip to content

Commit

Permalink
Compute source maps when pretty-printing jaxprs.
Browse files Browse the repository at this point in the history
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
  • Loading branch information
hawkinsp committed May 6, 2024
1 parent 7681493 commit d014f5d
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 60 deletions.
104 changes: 58 additions & 46 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,51 +149,11 @@ def __str__(self):
def pretty_print(self, *, source_info=False, print_shapes=True,
custom_pp_eqn_rules=True, name_stack=False,
print_effects: bool = False, **kwargs):
context = JaxprPpContext()
settings = JaxprPpSettings(
source_info=source_info,
print_shapes=print_shapes,
custom_pp_eqn_rules=custom_pp_eqn_rules,
name_stack=name_stack,
print_effects=print_effects)

# Compute how many times each jaxpr is used.
names = defaultdict[Jaxpr, str](lambda: "jaxpr")
jaxpr_counts = Counter[Jaxpr]()
s = deque([self])
while s:
jaxpr = s.popleft()
jaxpr_counts[jaxpr] += 1
for eqn in jaxpr.eqns:
# TODO(slebedev): Come up with a more elaborate heuristic for name=.
name = eqn.params.get("name")
if name is None:
s.extend(jaxprs_in_params(eqn.params))
continue
name = name.strip("<>") # <lambda> -> lambda
for subjaxpr in jaxprs_in_params(eqn.params):
s.append(subjaxpr)
names.setdefault(subjaxpr, name)

# Pull jaxprs occurring more than once to the top-level, making sure
# that their names are unique.
docs = []
name_counts = Counter[str]()
for jaxpr, c in jaxpr_counts.items():
if c == 1:
continue
name = names[jaxpr]
if (count := name_counts[name]) > 0:
name_counts[name] += 1
name += str(count)
name_counts[name] += 1
else:
name_counts[name] += 1
docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings))
context.used_names.add(name)
context.top_level_jaxprs[jaxpr] = name
docs.append(pp_jaxpr(self, context, settings))
return pp.concat(docs).format(**kwargs)
doc = pp_toplevel_jaxpr(
self, source_info=source_info, print_shapes=print_shapes,
custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack,
print_effects=print_effects)
return doc.format(**kwargs)

def _repr_pretty_(self, p, cycle):
return p.text(self.pretty_print(use_color=True))
Expand All @@ -212,6 +172,7 @@ def replace(self, **kwargs):
return jaxpr



def join_effects(*effects: Effects) -> Effects:
return set().union(*effects) if effects else no_effects

Expand Down Expand Up @@ -3164,6 +3125,55 @@ def _check_map(ctx_factory, prim, in_avals, params):

# ------------------- Jaxpr printed representation -------------------

def pp_toplevel_jaxpr(jaxpr_to_print, *, source_info=False, print_shapes=True,
custom_pp_eqn_rules=True, name_stack=False,
print_effects: bool = False) -> pp.Doc:
context = JaxprPpContext()
settings = JaxprPpSettings(
source_info=source_info,
print_shapes=print_shapes,
custom_pp_eqn_rules=custom_pp_eqn_rules,
name_stack=name_stack,
print_effects=print_effects)

# Compute how many times each jaxpr is used.
names = defaultdict[Jaxpr, str](lambda: "jaxpr")
jaxpr_counts = Counter[Jaxpr]()
s = deque([jaxpr_to_print])
while s:
jaxpr = s.popleft()
jaxpr_counts[jaxpr] += 1
for eqn in jaxpr.eqns:
# TODO(slebedev): Come up with a more elaborate heuristic for name=.
name = eqn.params.get("name")
if name is None:
s.extend(jaxprs_in_params(eqn.params))
continue
name = name.strip("<>") # <lambda> -> lambda
for subjaxpr in jaxprs_in_params(eqn.params):
s.append(subjaxpr)
names.setdefault(subjaxpr, name)

# Pull jaxprs occurring more than once to the top-level, making sure
# that their names are unique.
docs = []
name_counts = Counter[str]()
for jaxpr, c in jaxpr_counts.items():
if c == 1:
continue
name = names[jaxpr]
if (count := name_counts[name]) > 0:
name_counts[name] += 1
name += str(count)
name_counts[name] += 1
else:
name_counts[name] += 1
docs.append(pp_top_level_jaxpr(name, jaxpr, context, settings))
context.used_names.add(name)
context.top_level_jaxprs[jaxpr] = name
docs.append(pp_jaxpr(jaxpr_to_print, context, settings))
return pp.concat(docs)


class JaxprPpSettings(NamedTuple):
print_shapes: bool = True
Expand Down Expand Up @@ -3253,7 +3263,9 @@ def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings
) -> pp.Doc:
rule = (_pp_eqn if not settings.custom_pp_eqn_rules else
pp_eqn_rules.get(eqn.primitive, _pp_eqn))
return rule(eqn, context, settings) # type: ignore[operator]
doc = rule(eqn, context, settings) # type: ignore[operator]
user_frame = source_info_util.user_frame(eqn.source_info)
return doc if user_frame is None else pp.source_map(doc, user_frame)

def _pp_eqn(eqn, context, settings, params=None) -> pp.Doc:
annotation = (source_info_util.summarize(eqn.source_info)
Expand Down
92 changes: 78 additions & 14 deletions jax/_src/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import enum
from functools import partial
import sys
from typing import NamedTuple
from typing import Any, NamedTuple

from jax._src import config
from jax._src import util
Expand Down Expand Up @@ -69,12 +69,23 @@ def _can_use_color() -> bool:
class Doc(util.StrictABC):
__slots__ = ()

def format(self, width: int = 80, use_color: bool | None = None,
annotation_prefix=" # ") -> str:
def format(
self, width: int = 80, *, use_color: bool | None = None,
annotation_prefix: str = " # ",
source_map: list[list[tuple[int, int, Any]]] | None = None
) -> str:
"""
Formats a pretty-printer document as a string.
Args:
source_map: for each line in the output, contains a list of
(start column, end column, source) tuples. Each tuple associates a
region of output text with a source.
"""
if use_color is None:
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
return _format(self, width, use_color=use_color,
annotation_prefix=annotation_prefix)
annotation_prefix=annotation_prefix, source_map=source_map)

def __str__(self):
return self.format()
Expand Down Expand Up @@ -147,6 +158,21 @@ def __init__(self, n: int, child: Doc):
def __repr__(self): return f"nest({self.n, self.child})"


_NO_SOURCE = object()

class _SourceMapDoc(Doc):
__slots__ = ("child", "source")
child: Doc
source: Any

def __init__(self, child: Doc, source: Any):
assert isinstance(child, Doc), child
self.child = child
self.source = source

def __repr__(self): return f"source({self.child}, {self.source})"


Color = enum.Enum("_Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE",
"MAGENTA", "CYAN", "WHITE", "RESET"])
Intensity = enum.Enum("_Intensity", ["DIM", "NORMAL", "BRIGHT"])
Expand Down Expand Up @@ -193,7 +219,7 @@ def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]]
agenda.append((i + doc.n, m, doc.child))
elif isinstance(doc, _GroupDoc):
agenda.append((i, _BreakMode.FLAT, doc.child))
elif isinstance(doc, _ColorDoc):
elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc):
agenda.append((i, m, doc.child))
else:
raise ValueError("Invalid document ", doc)
Expand Down Expand Up @@ -224,7 +250,7 @@ def _sparse(doc: Doc) -> bool:
agenda.append(doc.child)
elif isinstance(doc, _GroupDoc):
agenda.append(doc.child)
elif isinstance(doc, _ColorDoc):
elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc):
agenda.append(doc.child)
else:
raise ValueError("Invalid document ", doc)
Expand All @@ -241,6 +267,7 @@ class _State(NamedTuple):
mode: _BreakMode
doc: Doc
color: _ColorState
source_map: Any

class _Line(NamedTuple):
text: str
Expand Down Expand Up @@ -283,17 +310,29 @@ def _align_annotations(lines):



def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
def _format(
doc: Doc, width: int, *, use_color: bool, annotation_prefix: str,
source_map: list[list[tuple[int, int, Any]]] | None
) -> str:
lines = []
default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL)
annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM)
color_state = default_colors
agenda = [_State(0, _BreakMode.BREAK, doc, default_colors)]
source_start = 0 # The column at which the current source region starts.
source = _NO_SOURCE # The currently active source region.
line_source_map = [] # Source maps for the current line of text.
agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)]
k = 0
line_text = ""
line_annotations = []
while len(agenda) > 0:
i, m, doc, color = agenda.pop()
i, m, doc, color, agenda_source = agenda.pop()
if source_map is not None and agenda_source != source:
pos = len(line_text)
if source_start != pos and source is not _NO_SOURCE:
line_source_map.append((source_start, pos, source))
source = agenda_source
source_start = pos
if isinstance(doc, _NilDoc):
pass
elif isinstance(doc, _TextDoc):
Expand All @@ -304,7 +343,7 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
line_annotations.append(doc.annotation)
k += len(doc.text)
elif isinstance(doc, _ConcatDoc):
agenda.extend(_State(i, m, d, color)
agenda.extend(_State(i, m, d, color, source)
for d in reversed(doc.children))
elif isinstance(doc, _BreakDoc):
if m == _BreakMode.BREAK:
Expand All @@ -313,6 +352,13 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
annotation_colors)
line_text += color_str
lines.append(_Line(line_text, k, line_annotations))
if source_map is not None:
pos = len(line_text)
if source_start != pos and source is not _NO_SOURCE:
line_source_map.append((source_start, pos, source))
source_map.append(line_source_map)
line_source_map = []
source_start = i
line_text = " " * i
line_annotations = []
k = i
Expand All @@ -322,27 +368,34 @@ def _format(doc: Doc, width: int, *, use_color, annotation_prefix) -> str:
line_text += doc.text
k += len(doc.text)
elif isinstance(doc, _NestDoc):
agenda.append(_State(i + doc.n, m, doc.child, color))
agenda.append(_State(i + doc.n, m, doc.child, color, source))
elif isinstance(doc, _GroupDoc):
# In Lindig's paper, _fits is passed the remainder of the document.
# I'm pretty sure that's a bug and we care only if the current group fits!
if (_sparse(doc)
and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])):
agenda.append(_State(i, _BreakMode.FLAT, doc.child, color))
agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source))
else:
agenda.append(_State(i, _BreakMode.BREAK, doc.child, color))
agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source))
elif isinstance(doc, _ColorDoc):
color = _ColorState(doc.foreground or color.foreground,
doc.background or color.background,
doc.intensity or color.intensity)
agenda.append(_State(i, m, doc.child, color))
agenda.append(_State(i, m, doc.child, color, source))
elif isinstance(doc, _SourceMapDoc):
agenda.append(_State(i, m, doc.child, color, doc.source))
else:
raise ValueError("Invalid document ", doc)

if len(line_annotations) > 0:
color_state, color_str = _update_color(use_color, color_state,
annotation_colors)
line_text += color_str
if source_map is not None:
pos = len(line_text)
if source_start != pos and source is not _NO_SOURCE:
line_source_map.append((source_start, pos, source))
source_map.append(line_source_map)
lines.append(_Line(line_text, k, line_annotations))
lines = _align_annotations(lines)
out = "\n".join(
Expand Down Expand Up @@ -406,6 +459,17 @@ def color(doc: Doc, *, foreground: Color | None = None,
intensity=intensity)


def source_map(doc: Doc, source: Any):
"""Source mapping.
A source map associates a region of the pretty-printer's text output with a
source location that produced it. For the purposes of the pretty printer a
``source`` may be any object: we require only that we can compare sources for
equality. A text region to source object mapping can be populated as a side
output of the ``format`` method.
"""
return _SourceMapDoc(doc, source)

type_annotation = partial(color, intensity=Intensity.NORMAL,
foreground=Color.MAGENTA)
keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE)
Expand Down
9 changes: 9 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,15 @@ jax_test(
],
)

py_test(
name = "pretty_printer_test",
srcs = ["pretty_printer_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)

exports_files(
[
"api_test.py",
Expand Down
36 changes: 36 additions & 0 deletions tests/pretty_printer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

from jax._src import test_util as jtu
from jax._src import pretty_printer as pp


class PrettyPrinterTest(jtu.JaxTestCase):

def testSourceMap(self):
doc = pp.concat([
pp.text("abc"), pp.source_map(pp.text("def"), 101),
pp.source_map(pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77),
pp.text("mn"),
])
source_map = []
out = doc.format(width=8, source_map=source_map)
self.assertEqual(out, "abcdefgh\nijklmn")
self.assertEqual(source_map, [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]])


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit d014f5d

Please sign in to comment.