diff --git a/hippunfold_plot/plotting.py b/hippunfold_plot/plotting.py index 75ea4d3..b3238c6 100644 --- a/hippunfold_plot/plotting.py +++ b/hippunfold_plot/plotting.py @@ -1,8 +1,27 @@ import matplotlib.pyplot as plt from nilearn.plotting import plot_surf from hippunfold_plot.utils import get_surf_limits, get_data_limits, get_resource_path, check_surf_map_is_label_gii, get_legend_elements_from_label_gii - -def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=(12, 8), dpi=300, vmin=None, vmax=None, colorbar=False, colorbar_shrink=0.25, cmap=None, view='dorsal', avg_method='median', bg_on_data=True, alpha=0.1, darkness=2, **kwargs): +from typing import Union, Tuple, Optional, List + +def plot_hipp_surf(surf_map: Union[str, list], + density: str = '0p5mm', + hemi: str = 'left', + space: Optional[str] = None, + figsize: Tuple[int, int] = (12, 8), + dpi: int = 300, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + colorbar: bool = False, + colorbar_shrink: float = 0.25, + cmap: Optional[Union[str, plt.cm.ScalarMappable]] = None, + view: str = 'dorsal', + avg_method: str = 'median', + bg_on_data: bool = True, + alpha: float = 0.1, + darkness: float = 2, + axes: Optional[Union[plt.Axes, List[plt.Axes]]] = None, + figure: Optional[plt.Figure] = None, + **kwargs) -> plt.Figure: """Plot hippocampal surface map. This function plots a surface map of the hippocampus, which can be a label-hippdentate shape.gii, func.gii, or a Vx1 array @@ -43,6 +62,11 @@ def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=( The alpha transparency level. Default is 0.1. darkness : float, optional The darkness level of the background. Default is 2. + axes : matplotlib.axes.Axes or list of matplotlib.axes.Axes, optional + Axes to plot on. If None, new axes will be created. If a single axis is provided, it will be used for a single plot. + If multiple plots are to be made, a list of axes should be provided. + figure : matplotlib.figure.Figure, optional + The figure to plot on. If None, a new figure will be created. **kwargs : dict Additional arguments to pass to nilearn's plot_surf(). @@ -87,19 +111,6 @@ def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=( #add any user arguments plot_kwargs.update(kwargs) - - # Create a figure - fig = plt.figure(figsize=figsize,dpi=dpi) # Adjust figure size for tall axes - - # Define positions for 4 tall side-by-side axes - positions = [ - [0.05, 0.1, 0.2, 0.8], # Left, bottom, width, height - [0.18, 0.1, 0.2, 0.8], - [0.30, 0.1, 0.2, 0.8], - [0.43, 0.1, 0.2, 0.8], - [0.55, 0.1, 0.2, 0.8], - - ] # Define the plotting order for each hemisphere hemi_space_map = { @@ -107,17 +118,39 @@ def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=( 'right': ['canonical', 'unfold'] } - + # Determine the number of plots to be made + hemis_to_plot = [hemi] if hemi else hemi_space_map.keys() + num_plots = sum(len([space] if space else hemi_space_map[h]) for h in hemis_to_plot) + + # Validate axes input + if axes is not None: + if isinstance(axes, plt.Axes): + if num_plots > 1: + raise ValueError("Multiple plots requested, but only one axis provided.") + axes = [axes] + elif isinstance(axes, list): + if len(axes) != num_plots: + raise ValueError(f"Expected {num_plots} axes, but got {len(axes)}.") + else: + raise ValueError("Invalid type for 'axes'. Expected matplotlib.axes.Axes or list of matplotlib.axes.Axes.") + + + # Create a figure if not provided + if fig is None: + fig = plt.figure(figsize=figsize, dpi=dpi) + + # Create axes if not provided + if axes is None: + axes = [fig.add_subplot(1, num_plots, i + 1, projection='3d') for i in range(num_plots)] + + pos=0 # Build the composite plot - hemis_to_plot = [hemi] if hemi else hemi_space_map.keys() for h in hemis_to_plot: spaces_to_plot = [space] if space else hemi_space_map[h] for s in spaces_to_plot: - - ax = fig.add_axes(positions[pos], projection='3d') # Add 3D axes - + ax = axes[pos] plot_surf(surf_mesh=surf_gii.format(hemi=h,space=s,density=density), axes=ax, figure=fig,