Skip to content

Commit

Permalink
Merge pull request #43 from SarthakJariwala/v0.4.x
Browse files Browse the repository at this point in the history
Stricter separation between axes and fig level functions
  • Loading branch information
SarthakJariwala authored Sep 6, 2020
2 parents 01154c0 + e6b2222 commit 349b620
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 92 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,23 @@

<div class="row">

<a href="https://seaborn-image.readthedocs.io/auto_examples/plot_image_hist.html">
<a href="https://seaborn-image.readthedocs.io/en/latest/auto_examples/plot_image_hist.html">
<img src="./images/sphx_glr_plot_image_hist_001.png" height="120" width="170">
</a>

<a href="https://seaborn-image.readthedocs.io/auto_examples/plot_filter.html">
<a href="https://seaborn-image.readthedocs.io/en/latest/auto_examples/plot_filter.html">
<img src="./images/sphx_glr_plot_filter_001.png" height="120" width="130">
</a>

<a href="https://seaborn-image.readthedocs.io/auto_examples/plot_fft.html">
<a href="https://seaborn-image.readthedocs.io/en/latest/auto_examples/plot_fft.html">
<img src="./images/sphx_glr_plot_fft_001.png" height="120" width="120">
</a>

<a href="https://seaborn-image.readthedocs.io/auto_examples/plot_filtergrid.html">
<a href="https://seaborn-image.readthedocs.io/en/latest/auto_examples/plot_filtergrid.html">
<img src="./images/sphx_glr_plot_filtergrid_001.png" height="120" width="120">
</a>

<a href="https://seaborn-image.readthedocs.io/auto_examples/plot_image_robust.html">
<a href="https://seaborn-image.readthedocs.io/en/latest/auto_examples/plot_image_robust.html">
<img src="./images/sphx_glr_plot_image_robust_001.png" height="120" width="260">
</a>

Expand Down
27 changes: 12 additions & 15 deletions src/seaborn_image/_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,17 @@ def filterplot(
Returns
-------
(tuple): tuple containing:
(`matplotlib.axes.Axes`): Matplotlib axes where the image is drawn.
(`matplotlib.axes.Axes`): Colorbar axes
(`numpy.array`): Filtered image data
`matplotlib.axes.Axes`
Matplotlib axes where the image is drawn.
Raises
------
TypeError
if `filt` is not a string type or callable function
NotImplementedError
if a `filt` that is not implemented is specified
TypeError
if `describe` is not a `bool`
TypeError
if `filt` is not a string type or callable function
NotImplementedError
if a `filt` that is not implemented is specified
TypeError
if `describe` is not a `bool`
Examples
--------
Expand Down Expand Up @@ -193,7 +190,7 @@ def filterplot(
filtered_data = filt_func(data, **func_kwargs)

# finally, plot the filtered image
f, ax, cax = imgplot(
ax = imgplot(
filtered_data,
ax=ax,
cmap=cmap,
Expand Down Expand Up @@ -224,7 +221,7 @@ def filterplot(
print(f"Variance : {result_1.variance}")
print(f"Skewness : {result_1.skewness}")

return ax, cax, filtered_data
return ax


def fftplot(
Expand All @@ -250,7 +247,7 @@ def fftplot(
# perform fft
data_f_mag = fftshift(np.abs(fftn(w_data)))

f, ax, cax = imgplot(
ax = imgplot(
np.log(data_f_mag),
ax=ax,
cmap=cmap,
Expand All @@ -261,4 +258,4 @@ def fftplot(
despine=despine,
)

return ax, cax
return ax
21 changes: 7 additions & 14 deletions src/seaborn_image/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,8 @@ def imgplot(
Returns
-------
`matplotlib.figure.Figure`
Matplotlib figure.
`matplotlib.axes.Axes`
Matplotlib axes where the image is drawn.
`matplotlib.axes.Axes`
Colorbar axes
Raises
------
Expand Down Expand Up @@ -279,10 +275,10 @@ def imgplot(
print(f"Variance : {result.variance}")
print(f"Skewness : {result.skewness}")

return f, ax, cax
return ax


# TODO implement a imgdist function with more distributions
# TODO implement a imgdist function with more distributions (?)
# TODO add height, aspect parameter
def imghist(
data,
Expand Down Expand Up @@ -367,12 +363,6 @@ def imghist(
-------
`matplotlib.figure.Figure`
Matplotlib figure.
`matplotlib.axes.Axes`
Matplotlib axes where the image is drawn.
`matplotlib.axes.Axes`
Matplotlib axes where the histogram is drawn.
`matplotlib.axes.Axes`
Colorbar axes
Raises
------
Expand Down Expand Up @@ -438,7 +428,7 @@ def imghist(

ax1 = f.add_subplot(gs[0])

f, ax1, cax = imgplot(
ax1 = imgplot(
data,
ax=ax1,
cmap=cmap,
Expand All @@ -458,6 +448,9 @@ def imghist(
despine=despine,
)

# get colorbar axes
cax = f.axes[1]

if orientation == "vertical":
ax2 = f.add_subplot(gs[1], sharey=cax)

Expand Down Expand Up @@ -496,4 +489,4 @@ def imghist(
for c, p in zip(col, patches):
plt.setp(p, "facecolor", cm(c))

return f, (ax1, ax2), cax
return f
6 changes: 4 additions & 2 deletions src/seaborn_image/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,10 @@ def map_filter_to_grid(self):
def _plot(self, ax, **func_kwargs):
"""Helper function to call the underlying filterplot
Args:
ax (`matplotlib.axes.Axes`): Axis to plot filtered image
Parameters
----------
ax : `matplotlib.axes.Axes`
Axis to plot filtered image
"""

filterplot(
Expand Down
8 changes: 6 additions & 2 deletions src/seaborn_image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def scientific_ticks(ax, which="y"):
>>> import seaborn_image as isns
>>> img = isns.load_image("polymer") * 1e-9
>>> f, ax, cax = isns.imgplot(img)
>>> ax = isns.imgplot(img)
>>> # get colorbar axes
>>> cax = plt.gcf().axes[1]
>>> isns.scientific_ticks(cax)
Set colorbar xaxis ticks to scientific
Expand All @@ -43,7 +45,9 @@ def scientific_ticks(ax, which="y"):
>>> import seaborn_image as isns
>>> img = isns.load_image("polymer") * 1e-9
>>> f, ax, cax = isns.imgplot(img, orientation="h")
>>> ax = isns.imgplot(img, orientation="h")
>>> # get colorbar axes
>>> cax = plt.gcf().axes[1]
>>> isns.scientific_ticks(cax, which="x")
"""

Expand Down
84 changes: 49 additions & 35 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,147 +34,161 @@ def test_describe_type(describe):
isns.filterplot(data, "sobel", describe=describe)


# @pytest.mark.parametrize("filt", isns.implemented_filters)
@pytest.mark.parametrize("describe", [True, False])
def test_filters(describe):
ax, cax, filt_data = isns.filterplot(data, "sobel", describe=describe)
def test_filterplot_describe(describe):
ax = isns.filterplot(data, "sobel", describe=describe)

assert isinstance(ax, Axes)
assert isinstance(cax, Axes)

plt.close("all")


def test_filterplot_callable_filt():
"Test a callable filt parameter with additional parameters passed to the callable filt function"
_, _, filt_data = isns.filterplot(data, ndi.uniform_filter, size=5, mode="nearest")
ax = isns.filterplot(data, ndi.uniform_filter, size=5, mode="nearest")

np.testing.assert_array_equal(
filt_data, ndi.uniform_filter(data, size=5, mode="nearest")
ax.images[0].get_array().data, ndi.uniform_filter(data, size=5, mode="nearest")
)

plt.close("all")


def test_filterplot_gaussian():
_, _, filt_data = isns.filterplot(data, filt="gaussian", sigma=1)
ax = isns.filterplot(data, filt="gaussian", sigma=1)

np.testing.assert_array_equal(filt_data, ndi.gaussian_filter(data, sigma=1))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.gaussian_filter(data, sigma=1)
)

plt.close("all")


def test_filterplot_sobel():
_, _, filt_data = isns.filterplot(data, filt="sobel")
ax = isns.filterplot(data, filt="sobel")

np.testing.assert_array_equal(filt_data, ndi.sobel(data))
np.testing.assert_array_equal(ax.images[0].get_array().data, ndi.sobel(data))

plt.close("all")


def test_filterplot_median():
_, _, filt_data = isns.filterplot(data, filt="median", size=5)
ax = isns.filterplot(data, filt="median", size=5)

np.testing.assert_array_equal(filt_data, ndi.median_filter(data, size=5))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.median_filter(data, size=5)
)

plt.close("all")


def test_filterplot_max():
_, _, filt_data = isns.filterplot(data, filt="max", size=5)
ax = isns.filterplot(data, filt="max", size=5)

np.testing.assert_array_equal(filt_data, ndi.maximum_filter(data, size=5))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.maximum_filter(data, size=5)
)

plt.close("all")


def test_filterplot_diff_of_gaussian():
_, _, filt_data = isns.filterplot(data, filt="diff_of_gaussians", low_sigma=1)
ax = isns.filterplot(data, filt="diff_of_gaussians", low_sigma=1)

np.testing.assert_array_equal(filt_data, difference_of_gaussians(data, low_sigma=1))
np.testing.assert_array_equal(
ax.images[0].get_array().data, difference_of_gaussians(data, low_sigma=1)
)

plt.close("all")


def test_filterplot_gaussian_gradient_magnitude():
_, _, filt_data = isns.filterplot(data, filt="gaussian_gradient_magnitude", sigma=1)
ax = isns.filterplot(data, filt="gaussian_gradient_magnitude", sigma=1)

np.testing.assert_array_equal(
filt_data, ndi.gaussian_gradient_magnitude(data, sigma=1)
ax.images[0].get_array().data, ndi.gaussian_gradient_magnitude(data, sigma=1)
)

plt.close("all")


def test_filterplot_gaussian_laplace():
_, _, filt_data = isns.filterplot(data, filt="gaussian_laplace", sigma=1)
ax = isns.filterplot(data, filt="gaussian_laplace", sigma=1)

np.testing.assert_array_equal(filt_data, ndi.gaussian_laplace(data, sigma=1))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.gaussian_laplace(data, sigma=1)
)

plt.close("all")


def test_filterplot_laplace():
_, _, filt_data = isns.filterplot(data, filt="laplace")
ax = isns.filterplot(data, filt="laplace")

np.testing.assert_array_equal(filt_data, ndi.laplace(data))
np.testing.assert_array_equal(ax.images[0].get_array().data, ndi.laplace(data))

plt.close("all")


def test_filterplot_min():
_, _, filt_data = isns.filterplot(data, filt="min", size=5)
ax = isns.filterplot(data, filt="min", size=5)

np.testing.assert_array_equal(filt_data, ndi.minimum_filter(data, size=5))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.minimum_filter(data, size=5)
)

plt.close("all")


def test_filterplot_percentile():
_, _, filt_data = isns.filterplot(data, filt="percentile", percentile=10, size=10)
ax = isns.filterplot(data, filt="percentile", percentile=10, size=10)

np.testing.assert_array_equal(
filt_data, ndi.percentile_filter(data, percentile=10, size=10)
ax.images[0].get_array().data,
ndi.percentile_filter(data, percentile=10, size=10),
)

plt.close("all")


def test_filterplot_prewitt():
_, _, filt_data = isns.filterplot(data, filt="prewitt")
ax = isns.filterplot(data, filt="prewitt")

np.testing.assert_array_equal(filt_data, ndi.prewitt(data))
np.testing.assert_array_equal(ax.images[0].get_array().data, ndi.prewitt(data))

plt.close("all")


def test_filterplot_rank():
_, _, filt_data = isns.filterplot(data, filt="rank", rank=1, size=10)
ax = isns.filterplot(data, filt="rank", rank=1, size=10)

np.testing.assert_array_equal(filt_data, ndi.rank_filter(data, rank=1, size=10))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.rank_filter(data, rank=1, size=10)
)

plt.close("all")


def test_filterplot_uniform():
_, _, filt_data = isns.filterplot(data, filt="uniform")
ax = isns.filterplot(data, filt="uniform")

np.testing.assert_array_equal(filt_data, ndi.uniform_filter(data))
np.testing.assert_array_equal(
ax.images[0].get_array().data, ndi.uniform_filter(data)
)

plt.close("all")


def test_fftplot_plot():
ax, cax = isns.fftplot(data)
ax = isns.fftplot(data)

assert isinstance(ax, Axes)
assert isinstance(cax, Axes)

plt.close("all")


def test_fftplot_fft():
ax, cax = isns.fftplot(data)
ax = isns.fftplot(data)

w_data = data * window("hann", data.shape)
data_f_mag = fftshift(np.abs(fftn(w_data)))
Expand Down
Loading

0 comments on commit 349b620

Please sign in to comment.