Skip to content

Commit

Permalink
fix #1260: include points in plotting limits
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Sep 4, 2024
1 parent bd740d6 commit c45dbae
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 33 deletions.
58 changes: 36 additions & 22 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,42 +554,56 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
return samples


def convert_to_list_of_numpy(
arr: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
) -> List[np.ndarray]:
"""Converts a list of torch.Tensor to a list of np.ndarray."""
if not isinstance(arr, list):
arr = ensure_numpy(arr)
return [arr]
return [ensure_numpy(a) for a in arr]


def infer_limits(
samples: List[np.ndarray], dim: int, points: Optional[List[np.ndarray]] = None
) -> List:
"""Infer limits for the plot."""
limits = []
for d in range(dim):
min_val = min(np.min(sample[:, d]) for sample in samples)
max_val = max(np.max(sample[:, d]) for sample in samples)
if points is not None:
min_val = min(min_val, min(np.min(point[:, d]) for point in points))
max_val = max(max_val, max(np.max(point[:, d]) for point in points))

Check warning on line 577 in sbi/analysis/plot.py

View check run for this annotation

Codecov / codecov/patch

sbi/analysis/plot.py#L576-L577

Added lines #L576 - L577 were not covered by tests
limits.append([min_val, max_val])
return limits


def prepare_for_plot(
samples: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
limits: Optional[Union[List, torch.Tensor, np.ndarray]],
limits: Optional[Union[List, torch.Tensor, np.ndarray]] = None,
points: Optional[
Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor]
] = None,
) -> Tuple[List[np.ndarray], int, torch.Tensor]:
"""
Ensures correct formatting for samples and limits, and returns dimension
of the samples.
"""

# Prepare samples
if not isinstance(samples, list):
samples = ensure_numpy(samples)
samples = [samples]
else:
samples = [ensure_numpy(sample) for sample in samples]
samples = convert_to_list_of_numpy(samples)
if points is not None:
points = convert_to_list_of_numpy(points)

# check if nans and infs
samples = handle_nan_infs(samples)

# Dimensionality of the problem.
dim = samples[0].shape[1]

# Prepare limits. Infer them from samples if they had not been passed.
if limits == [] or limits is None:
limits = []
for d in range(dim):
min = +np.inf
max = -np.inf
for sample in samples:
min_ = np.min(sample[:, d])
min = min_ if min_ < min else min
max_ = np.max(sample[:, d])
max = max_ if max_ > max else max
limits.append([min, max])
if limits is None or limits == []:
limits = infer_limits(samples, dim, points)
else:
limits = [limits[0] for _ in range(dim)] if len(limits) == 1 else limits

limits = torch.as_tensor(limits)
return samples, dim, limits

Expand Down Expand Up @@ -737,7 +751,7 @@ def pairplot(
)
return fig, axes

samples, dim, limits = prepare_for_plot(samples, limits)
samples, dim, limits = prepare_for_plot(samples, limits, points)

# prepate figure kwargs
fig_kwargs_filled = _get_default_fig_kwargs()
Expand Down
Loading

0 comments on commit c45dbae

Please sign in to comment.