Skip to content

Commit

Permalink
implementing starMap refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Shannon Axelrod committed Sep 27, 2019
1 parent 89a13f2 commit f2148b4
Show file tree
Hide file tree
Showing 15 changed files with 501 additions and 76 deletions.
27 changes: 15 additions & 12 deletions notebooks/STARmap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"import starfish\n",
"from starfish import IntensityTable\n",
"import starfish.data\n",
"from starfish.types import Axes\n",
"from starfish.types import Axes, TraceBuildingStrategies\n",
"from starfish.util.plot import (\n",
" diagnose_registration, imshow_plane, intensity_histogram\n",
")\n",
Expand Down Expand Up @@ -302,16 +302,20 @@
"metadata": {},
"outputs": [],
"source": [
"lsbd = starfish.spots.DetectSpots.LocalSearchBlobDetector(\n",
" min_sigma=1,\n",
" max_sigma=8,\n",
" num_sigma=10,\n",
" threshold=np.percentile(np.ravel(stack.xarray.values), 95),\n",
" exclude_border=2,\n",
" anchor_round=0,\n",
" search_radius=10,\n",
")\n",
"intensities = lsbd.run(scaled, n_processes=8)"
"bd = starfish.spots.FindSpots.BlobDetector(min_sigma=1,\n",
" max_sigma=8,\n",
" num_sigma=10,\n",
" threshold=np.percentile(np.ravel(stack.xarray.values), 95),\n",
" exclude_border=2)\n",
"\n",
"spots = bd.run(scaled, n_processes=8)\n",
"decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook,\n",
" anchor_round=0,\n",
" search_radius=10,\n",
" trace_building_strategy=\n",
" TraceBuildingStrategies.NEAREST_NEIGHBOR)\n",
"\n",
"decoded = decoder.run(spots=spots)"
]
},
{
Expand All @@ -331,7 +335,6 @@
"metadata": {},
"outputs": [],
"source": [
"decoded = experiment.codebook.decode_per_round_max(IntensityTable(intensities.fillna(0)))\n",
"decode_mask = decoded['target'] != 'nan'\n",
"\n",
"# %gui qt\n",
Expand Down
26 changes: 14 additions & 12 deletions notebooks/SeqFISH.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
"\n",
"import starfish\n",
"import starfish.data\n",
"from starfish.spots import DetectSpots\n",
"from starfish.types import Axes"
"from starfish.types import Axes, TraceBuildingStrategies"
]
},
{
Expand Down Expand Up @@ -196,15 +195,18 @@
"source": [
"threshold = 0.5\n",
"\n",
"lsbd = starfish.spots.DetectSpots.LocalSearchBlobDetector(\n",
" min_sigma=(1.5, 1.5, 1.5),\n",
" max_sigma=(8, 8, 8),\n",
" num_sigma=10,\n",
" threshold=threshold,\n",
" search_radius=7\n",
")\n",
"intensities = lsbd.run(clipped)\n",
"decoded = exp.codebook.decode_per_round_max(intensities.fillna(0))"
"bd = starfish.spots.FindSpots.BlobDetector(min_sigma=(1.5, 1.5, 1.5),\n",
" max_sigma=(8, 8, 8),\n",
" num_sigma=10,\n",
" threshold=threshold)\n",
"\n",
"spots = bd.run(clipped)\n",
"decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(codebook=exp.codebook,\n",
" search_radius=7,\n",
" trace_building_strategy=\n",
" TraceBuildingStrategies.NEAREST_NEIGHBOR)\n",
"\n",
"decoded = decoder.run(spots=spots)"
]
},
{
Expand All @@ -213,7 +215,7 @@
"metadata": {},
"outputs": [],
"source": [
"starfish.display(clipped, intensities)"
"starfish.display(clipped, decoded)"
]
},
{
Expand Down
27 changes: 15 additions & 12 deletions notebooks/py/STARmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import starfish
from starfish import IntensityTable
import starfish.data
from starfish.types import Axes
from starfish.types import Axes, TraceBuildingStrategies
from starfish.util.plot import (
diagnose_registration, imshow_plane, intensity_histogram
)
Expand Down Expand Up @@ -213,16 +213,20 @@ def plot_scaling_result(
# EPY: END markdown

# EPY: START code
lsbd = starfish.spots.DetectSpots.LocalSearchBlobDetector(
min_sigma=1,
max_sigma=8,
num_sigma=10,
threshold=np.percentile(np.ravel(stack.xarray.values), 95),
exclude_border=2,
anchor_round=0,
search_radius=10,
)
intensities = lsbd.run(scaled, n_processes=8)
bd = starfish.spots.FindSpots.BlobDetector(min_sigma=1,
max_sigma=8,
num_sigma=10,
threshold=np.percentile(np.ravel(stack.xarray.values), 95),
exclude_border=2)

spots = bd.run(scaled, n_processes=8)
decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook,
anchor_round=0,
search_radius=10,
trace_building_strategy=
TraceBuildingStrategies.NEAREST_NEIGHBOR)

decoded = decoder.run(spots=spots)
# EPY: END code

# EPY: START markdown
Expand All @@ -234,7 +238,6 @@ def plot_scaling_result(
# EPY: END markdown

# EPY: START code
decoded = experiment.codebook.decode_per_round_max(IntensityTable(intensities.fillna(0)))
decode_mask = decoded['target'] != 'nan'

# %gui qt
Expand Down
26 changes: 14 additions & 12 deletions notebooks/py/SeqFISH.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

import starfish
import starfish.data
from starfish.spots import DetectSpots
from starfish.types import Axes
from starfish.types import Axes, TraceBuildingStrategies
# EPY: END code

# EPY: START markdown
Expand Down Expand Up @@ -117,19 +116,22 @@
# EPY: START code
threshold = 0.5

lsbd = starfish.spots.DetectSpots.LocalSearchBlobDetector(
min_sigma=(1.5, 1.5, 1.5),
max_sigma=(8, 8, 8),
num_sigma=10,
threshold=threshold,
search_radius=7
)
intensities = lsbd.run(clipped)
decoded = exp.codebook.decode_per_round_max(intensities.fillna(0))
bd = starfish.spots.FindSpots.BlobDetector(min_sigma=(1.5, 1.5, 1.5),
max_sigma=(8, 8, 8),
num_sigma=10,
threshold=threshold)

spots = bd.run(clipped)
decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(codebook=exp.codebook,
search_radius=7,
trace_building_strategy=
TraceBuildingStrategies.NEAREST_NEIGHBOR)

decoded = decoder.run(spots=spots)
# EPY: END code

# EPY: START code
starfish.display(clipped, intensities)
starfish.display(clipped, decoded)
# EPY: END code

# EPY: START markdown
Expand Down
18 changes: 13 additions & 5 deletions starfish/core/spots/DecodeSpots/per_round_max_channel_decoder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Callable

from starfish.core.codebook.codebook import Codebook
from starfish.core.intensity_table.decoded_intensity_table import DecodedIntensityTable
from starfish.core.intensity_table.intensity_table_coordinates import \
transfer_physical_coords_to_intensity_table
from starfish.core.spots.DecodeSpots.trace_builders import build_spot_traces_exact_match
from starfish.core.types import SpotFindingResults
from starfish.core.spots.DecodeSpots.trace_builders import trace_builders
from starfish.core.types import SpotFindingResults, TraceBuildingStrategies
from ._base import DecodeSpotsAlgorithm


Expand All @@ -23,8 +25,13 @@ class PerRoundMaxChannel(DecodeSpotsAlgorithm):
"""

def __init__(self, codebook: Codebook):
def __init__(self, codebook: Codebook, anchor_round: int=1, search_radius: int=3,
trace_building_strategy:
TraceBuildingStrategies=TraceBuildingStrategies.EXACT_MATCH):
self.codebook = codebook
self.trace_builder: Callable = trace_builders[trace_building_strategy]
self.anchor_round = anchor_round
self.search_radius = search_radius

def run(self, spots: SpotFindingResults, *args) -> DecodedIntensityTable:
"""Decode spots by selecting the max-valued channel in each sequencing round
Expand All @@ -40,7 +47,8 @@ def run(self, spots: SpotFindingResults, *args) -> DecodedIntensityTable:
IntensityTable decoded and appended with Features.TARGET and Features.QUALITY values.
"""
# if no spots
intensities = build_spot_traces_exact_match(spots)
intensities = self.trace_builder(spot_results=spots,
anchor_round=self.anchor_round,
search_radius=self.search_radius)
transfer_physical_coords_to_intensity_table(intensity_table=intensities, spots=spots)
return self.codebook.decode_per_round_max(intensities)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import numpy as np
from scipy.ndimage.filters import gaussian_filter

from starfish import ImageStack
from starfish.core.spots.DecodeSpots.trace_builders import build_traces_nearest_neighbors
from starfish.core.spots.FindSpots import BlobDetector
from starfish.core.types import Axes


def traversing_code() -> ImageStack:
"""this code walks in a sequential direction, and should only be detectable from some anchors"""
img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32)

# code 1
img[0, 0, 5, 35, 35] = 10
img[1, 1, 5, 32, 32] = 10
img[2, 0, 5, 29, 29] = 10

# blur points
gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img)

return ImageStack.from_numpy(img)


def multiple_possible_neighbors() -> ImageStack:
"""this image is intended to be tested with anchor_round in {0, 1}, last round has more spots"""
img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32)

# round 1
img[0, 0, 5, 20, 40] = 10
img[0, 0, 5, 40, 20] = 10

# round 2
img[1, 1, 5, 20, 40] = 10
img[1, 1, 5, 40, 20] = 10

# round 3
img[2, 0, 5, 20, 40] = 10
img[2, 0, 5, 35, 35] = 10
img[2, 0, 5, 40, 20] = 10

# blur points
gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img)

return ImageStack.from_numpy(img)


def jitter_code() -> ImageStack:
"""this code has some minor jitter <= 3px at the most distant point"""
img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32)

# code 1
img[0, 0, 5, 35, 35] = 10
img[1, 1, 5, 34, 35] = 10
img[2, 0, 6, 35, 33] = 10

# blur points
gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img)

return ImageStack.from_numpy(img)


def two_perfect_codes() -> ImageStack:
"""this code has no jitter"""
img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32)

# code 1
img[0, 0, 5, 20, 35] = 10
img[1, 1, 5, 20, 35] = 10
img[2, 0, 5, 20, 35] = 10

# code 1
img[0, 0, 5, 40, 45] = 10
img[1, 1, 5, 40, 45] = 10
img[2, 0, 5, 40, 45] = 10

# blur points
gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img)

return ImageStack.from_numpy(img)


def blob_detector():
return BlobDetector(min_sigma=1, max_sigma=4, num_sigma=30, threshold=.1)


def test_local_search_blob_detector_two_codes():
stack = two_perfect_codes()
bd = blob_detector()
spot_results = bd.run(stack, n_processes=1)

intensity_table = build_traces_nearest_neighbors(spot_results=spot_results, anchor_round=1,
search_radius=1)

assert intensity_table.shape == (2, 2, 3)
assert np.all(intensity_table[0][Axes.X.value] == 45)
assert np.all(intensity_table[0][Axes.Y.value] == 40)
assert np.all(intensity_table[0][Axes.ZPLANE.value] == 5)


def test_local_search_blob_detector_jitter_code():
stack = jitter_code()

bd = blob_detector()
spot_results = bd.run(stack, n_processes=1)
intensity_table = build_traces_nearest_neighbors(spot_results=spot_results, anchor_round=1,
search_radius=3)

assert intensity_table.shape == (1, 2, 3)
f, c, r = np.where(~intensity_table.isnull())
assert np.all(f == np.array([0, 0, 0]))
assert np.all(c == np.array([0, 0, 1]))
assert np.all(r == np.array([0, 2, 1]))

# test again with smaller search radius
bd = BlobDetector(min_sigma=1, max_sigma=4, num_sigma=30, threshold=.1)
per_tile_spot_results = bd.run(stack, n_processes=1)

intensity_table = build_traces_nearest_neighbors(spot_results=per_tile_spot_results,
anchor_round=0,
search_radius=1)

assert intensity_table.shape == (1, 2, 3)
f, c, r = np.where(~intensity_table.isnull())
assert np.all(f == np.array([0]))
assert np.all(c == np.array([0]))
assert np.all(r == np.array([0]))


def test_local_search_blob_detector_traversing_code():
stack = traversing_code()

bd = blob_detector()
spot_results = bd.run(stack, n_processes=1)
intensity_table = build_traces_nearest_neighbors(spot_results=spot_results, anchor_round=0,
search_radius=5)

assert intensity_table.shape == (1, 2, 3)
f, c, r = np.where(~intensity_table.isnull())
assert np.all(f == np.array([0, 0]))
assert np.all(c == np.array([0, 1]))
assert np.all(r == np.array([0, 1]))

bd = blob_detector()
spot_results = bd.run(stack, n_processes=1)
intensity_table = build_traces_nearest_neighbors(spot_results=spot_results, anchor_round=1,
search_radius=5)

f, c, r = np.where(~intensity_table.isnull())
assert np.all(f == np.array([0, 0, 0]))
assert np.all(c == np.array([0, 0, 1]))
assert np.all(r == np.array([0, 2, 1]))


def test_local_search_blob_detector_multiple_neighbors():
stack = multiple_possible_neighbors()

bd = blob_detector()
spot_results = bd.run(stack, n_processes=1)
intensity_table = build_traces_nearest_neighbors(spot_results=spot_results, anchor_round=0,
search_radius=4)

assert intensity_table.shape == (2, 2, 3)
f, c, r = np.where(~intensity_table.isnull())
assert np.all(intensity_table[Axes.ZPLANE.value] == (5, 5))
assert np.all(intensity_table[Axes.Y.value] == (40, 20))
assert np.all(intensity_table[Axes.X.value] == (20, 40))
Loading

0 comments on commit f2148b4

Please sign in to comment.