Skip to content

Commit

Permalink
[vis] Improved Matplotlib annotation positions
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Oct 7, 2024
1 parent a72de92 commit a8c0d88
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,21 +776,21 @@ def _annotate_points(axis, points: math.Tensor, color: Tensor, alpha: Tensor, di
set_ticks(axis, which_axis, reshaped_numpy(points.vector[labeled_dim.name], [shape]))
return # The point labels match one of the figure axes, so they are redundant
if points.shape['vector'].size == 2:
xs, ys = reshaped_numpy(points, ['vector', points.shape.without('vector')])
np_points = points.numpy([..., 'vector'])
rel_pos = axis.transAxes.inverted().transform(axis.transData.transform(np_points))
x_view = axis.get_xlim()[1] - axis.get_xlim()[0]
y_view = axis.get_ylim()[1] - axis.get_ylim()[0]
x_c = .95 * axis.get_xlim()[1] + .1 * axis.get_xlim()[0]
y_c = .95 * axis.get_ylim()[1] + .1 * axis.get_ylim()[0]
for x, y, idx, idx_n in zip(xs, ys, labeled_dims.meshgrid(), labeled_dims.meshgrid(names=True)):
for (x, y), (rx, ry), idx, idx_n in zip(np_points, rel_pos, labeled_dims.meshgrid(), labeled_dims.meshgrid(names=True)):
horizontal_align = 'right' if rx >= .5 else 'left'
if axis.get_xscale() == 'log':
offset_x = x * (1 + .0003 * x_view) if x < x_c else x * (1 - .0003 * x_view)
offset_x = x * (1 + .0003 * x_view) if rx < .5 else x * (1 - .0003 * x_view)
else:
offset_x = x + .11 * x_view if x < x_c else x - .26 * x_view
offset_x = x + .01 * x_view if rx < .5 else x - .01 * x_view
if axis.get_yscale() == 'log':
offset_y = y * (1 + .0003 * y_view) if y < y_c else y * (1 - .0003 * y_view)
offset_y = y * (1 + .0003 * y_view) if ry < .5 else y * (1 - .0003 * y_view)
else:
offset_y = y + .01 * y_view if y < y_c else y - .01 * y_view
axis.text(offset_x, offset_y, index_label(idx_n), color=_plt_col(color[idx]), alpha=float(alpha[idx]))
offset_y = y + .01 * y_view if ry < .5 else y - .01 * y_view
axis.text(offset_x, offset_y, index_label(idx_n), color=_plt_col(color[idx]), alpha=float(alpha[idx]), ha=horizontal_align)


class PointCloud3D(Recipe):
Expand Down

0 comments on commit a8c0d88

Please sign in to comment.