Skip to content

Commit

Permalink
(DiamondLightSource/hyperion#863) Make new pin tip detection more sim…
Browse files Browse the repository at this point in the history
…ilar to old one
  • Loading branch information
DominicOram committed Jan 12, 2024
1 parent 02d4299 commit f3a915a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 49 deletions.
4 changes: 2 additions & 2 deletions src/dodal/devices/areadetector/plugins/MXSC.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ class MXSC(Device):
canny_upper_threshold = Component(EpicsSignal, "CannyUpper")
canny_lower_threshold = Component(EpicsSignal, "CannyLower")
close_ksize = Component(EpicsSignal, "CloseKsize")
sample_detection_scan_direction = Component(EpicsSignal, "ScanDirection")
sample_detection_min_tip_height = Component(EpicsSignal, "MinTipHeight")
scan_direction = Component(EpicsSignal, "ScanDirection")
min_tip_height = Component(EpicsSignal, "MinTipHeight")

top = Component(EpicsSignal, "Top")
bottom = Component(EpicsSignal, "Bottom")
Expand Down
40 changes: 25 additions & 15 deletions src/dodal/devices/oav/pin_image_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Optional, Tuple

import numpy as np
from bluesky.protocols import Descriptor, Readable, Reading
from bluesky.protocols import Descriptor, Reading
from numpy.typing import NDArray
from ophyd_async.core import Device, SignalR, SignalRW
from ophyd_async.core import SignalR, SignalRW, StandardReadable
from ophyd_async.epics.signal import epics_signal_r

from dodal.devices.oav.pin_image_recognition.utils import (
Expand All @@ -19,7 +19,7 @@
from dodal.log import LOGGER


class PinTipDetection(Readable, Device):
class PinTipDetection(StandardReadable):
"""
A device which will read a single frame from an on-axis view and use that frame
to calculate the pin-tip offset (in pixels) of that frame.
Expand All @@ -31,6 +31,8 @@ class PinTipDetection(Readable, Device):
then it will return (None, None).
"""

INVALID_POSITION = (None, None)

def __init__(self, prefix: str, name: str = ""):
self._prefix: str = prefix
self._name = name
Expand All @@ -43,7 +45,7 @@ def __init__(self, prefix: str, name: str = ""):
self.timeout: SignalRW[float] = create_soft_signal_rw(
float, "timeout", self.name
)
self.preprocess: SignalRW[int] = create_soft_signal_rw(
self.preprocess_operation: SignalRW[int] = create_soft_signal_rw(
int, "preprocess", self.name
)
self.preprocess_ksize: SignalRW[int] = create_soft_signal_rw(
Expand All @@ -52,10 +54,10 @@ def __init__(self, prefix: str, name: str = ""):
self.preprocess_iterations: SignalRW[int] = create_soft_signal_rw(
int, "preprocess_iterations", self.name
)
self.canny_upper: SignalRW[int] = create_soft_signal_rw(
self.canny_upper_threshold: SignalRW[int] = create_soft_signal_rw(
int, "canny_upper", self.name
)
self.canny_lower: SignalRW[int] = create_soft_signal_rw(
self.canny_lower_threshold: SignalRW[int] = create_soft_signal_rw(
int, "canny_lower", self.name
)
self.close_ksize: SignalRW[int] = create_soft_signal_rw(
Expand All @@ -70,6 +72,9 @@ def __init__(self, prefix: str, name: str = ""):
self.min_tip_height: SignalRW[int] = create_soft_signal_rw(
int, "min_tip_height", self.name
)
self.validity_timeout: SignalR[float] = create_soft_signal_rw(
float, "validity_timeout", self.name
)

super().__init__(name=name)

Expand All @@ -82,7 +87,7 @@ async def _get_tip_position(
Returns tuple of:
((tip_x, tip_y), timestamp)
"""
preprocess_key = await self.preprocess.get_value()
preprocess_key = await self.preprocess_operation.get_value()
preprocess_iter = await self.preprocess_iterations.get_value()
preprocess_ksize = await self.preprocess_ksize.get_value()

Expand All @@ -94,13 +99,19 @@ async def _get_tip_position(
LOGGER.error("Invalid preprocessing function, using identity")
preprocess_func = identity()

direction = (
ScanDirections.FORWARD
if await self.scan_direction.get_value() == 0
else ScanDirections.REVERSE
)

sample_detection = MxSampleDetect(
preprocess=preprocess_func,
canny_lower=await self.canny_lower.get_value(),
canny_upper=await self.canny_upper.get_value(),
canny_lower=await self.canny_lower_threshold.get_value(),
canny_upper=await self.canny_upper_threshold.get_value(),
close_ksize=await self.close_ksize.get_value(),
close_iterations=await self.close_iterations.get_value(),
scan_direction=await self.scan_direction.get_value(),
scan_direction=direction,
min_tip_height=await self.min_tip_height.get_value(),
)

Expand All @@ -121,8 +132,7 @@ async def _get_tip_position(
tip_y = location.tip_y
except Exception as e:
LOGGER.error(f"Failed to detect pin-tip location due to exception: {e}")
tip_x = None
tip_y = None
tip_x, tip_y = self.INVALID_POSITION

return (tip_x, tip_y), timestamp

Expand All @@ -131,13 +141,13 @@ async def connect(self, sim: bool = False):

# Set defaults for soft parameters
await self.timeout.set(10.0)
await self.canny_upper.set(100)
await self.canny_lower.set(50)
await self.canny_upper_threshold.set(100)
await self.canny_lower_threshold.set(50)
await self.close_iterations.set(5)
await self.close_ksize.set(5)
await self.scan_direction.set(ScanDirections.FORWARD.value)
await self.min_tip_height.set(5)
await self.preprocess.set(10) # Identity function
await self.preprocess_operation.set(10) # Identity function
await self.preprocess_iterations.set(5)
await self.preprocess_ksize.set(5)

Expand Down
20 changes: 6 additions & 14 deletions src/dodal/devices/oav/pin_image_recognition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
canny_lower: int = 50,
close_ksize: int = 5,
close_iterations: int = 5,
scan_direction: int = 1,
scan_direction: ScanDirections = ScanDirections.FORWARD,
min_tip_height: int = 5,
):
"""
Expand All @@ -126,7 +126,7 @@ def __init__(
canny_lower: lower threshold for canny edge detection
close_ksize: kernel size for "close" operation
close_iterations: number of iterations for "close" operation
scan_direction: +1 for left-to-right, -1 for right-to-left
scan_direction: ScanDirections.FORWARD for left-to-right, ScanDirections.REVERSE for right-to-left
min_tip_height: minimum height of pin tip
"""

Expand All @@ -135,14 +135,6 @@ def __init__(
self.canny_lower = canny_lower
self.close_ksize = close_ksize
self.close_iterations = close_iterations

if scan_direction not in [
ScanDirections.FORWARD.value,
ScanDirections.REVERSE.value,
]:
raise ValueError(
"Invalid scan direction, expected +1 for left-to-right or -1 for right-to-left"
)
self.scan_direction = scan_direction

self.min_tip_height = min_tip_height
Expand Down Expand Up @@ -221,7 +213,7 @@ def _locate_sample(self, edge_arr: np.ndarray) -> SampleLocation:
)

# Choose our starting point - i.e. first column with non-narrow width for positive scan, last one for negative scan.
if self.scan_direction == ScanDirections.FORWARD.value:
if self.scan_direction == ScanDirections.FORWARD:
start_column = int(column_indices_with_non_narrow_widths[0])
else:
start_column = int(column_indices_with_non_narrow_widths[-1])
Expand All @@ -230,20 +222,20 @@ def _locate_sample(self, edge_arr: np.ndarray) -> SampleLocation:

# Move backwards to where there were no edges at all...
while top[x] != NONE_VALUE:
x += -self.scan_direction
x += -self.scan_direction.value
if x == -1 or x == width:
# (In this case the sample is off the edge of the picture.)
LOGGER.warning(
"pin-tip detection: Pin tip may be outside image area - assuming at edge."
)
break
x += self.scan_direction # ...and forward one step. x is now at the tip.
x += self.scan_direction.value # ...and forward one step. x is now at the tip.

tip_x = x
tip_y = int(round(0.5 * (top[x] + bottom[x])))

# clear edges to the left (right) of the tip.
if self.scan_direction == 1:
if self.scan_direction.value == 1:
top[:x] = NONE_VALUE
bottom[:x] = NONE_VALUE
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ async def test_soft_parameter_defaults_are_correct():
device = await _get_pin_tip_detection_device()

assert await device.timeout.get_value() == 10.0
assert await device.canny_lower.get_value() == 50
assert await device.canny_upper.get_value() == 100
assert await device.canny_lower_threshold.get_value() == 50
assert await device.canny_upper_threshold.get_value() == 100
assert await device.close_ksize.get_value() == 5
assert await device.close_iterations.get_value() == 5
assert await device.min_tip_height.get_value() == 5
assert await device.scan_direction.get_value() == 1
assert await device.preprocess.get_value() == 10
assert await device.preprocess_operation.get_value() == 10
assert await device.preprocess_iterations.get_value() == 5
assert await device.preprocess_ksize.get_value() == 5

Expand All @@ -45,24 +45,24 @@ async def test_numeric_soft_parameters_can_be_changed():
device = await _get_pin_tip_detection_device()

await device.timeout.set(100.0)
await device.canny_lower.set(5)
await device.canny_upper.set(10)
await device.canny_lower_threshold.set(5)
await device.canny_upper_threshold.set(10)
await device.close_ksize.set(15)
await device.close_iterations.set(20)
await device.min_tip_height.set(25)
await device.scan_direction.set(-1)
await device.preprocess.set(2)
await device.preprocess_operation.set(2)
await device.preprocess_ksize.set(3)
await device.preprocess_iterations.set(4)

assert await device.timeout.get_value() == 100.0
assert await device.canny_lower.get_value() == 5
assert await device.canny_upper.get_value() == 10
assert await device.canny_lower_threshold.get_value() == 5
assert await device.canny_upper_threshold.get_value() == 10
assert await device.close_ksize.get_value() == 15
assert await device.close_iterations.get_value() == 20
assert await device.min_tip_height.get_value() == 25
assert await device.scan_direction.get_value() == -1
assert await device.preprocess.get_value() == 2
assert await device.preprocess_operation.get_value() == 2
assert await device.preprocess_ksize.get_value() == 3
assert await device.preprocess_iterations.get_value() == 4

Expand All @@ -71,7 +71,7 @@ async def test_numeric_soft_parameters_can_be_changed():
async def test_invalid_processing_func_uses_identity_function():
device = await _get_pin_tip_detection_device()

set_sim_value(device.preprocess, 50) # Invalid index
set_sim_value(device.preprocess_operation, 50) # Invalid index

with patch.object(
MxSampleDetect, "__init__", return_value=None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import pytest

from dodal.devices.oav.pin_image_recognition.utils import NONE_VALUE, MxSampleDetect
from dodal.devices.oav.pin_image_recognition.utils import (
NONE_VALUE,
MxSampleDetect,
ScanDirections,
)


def test_locate_sample_simple_forward():
Expand Down Expand Up @@ -44,9 +48,9 @@ def test_locate_sample_simple_reverse():
dtype=np.int32,
)

location = MxSampleDetect(min_tip_height=1, scan_direction=-1)._locate_sample(
test_arr
)
location = MxSampleDetect(
min_tip_height=1, scan_direction=ScanDirections.REVERSE
)._locate_sample(test_arr)

assert location.edge_top is not None
assert location.edge_bottom is not None
Expand Down Expand Up @@ -92,7 +96,9 @@ def test_locate_sample_no_edges():
assert location.tip_y is None


@pytest.mark.parametrize("direction,x_centre", [(1, 0), (-1, 4)])
@pytest.mark.parametrize(
"direction,x_centre", [(ScanDirections.FORWARD, 0), (ScanDirections.REVERSE, 4)]
)
def test_locate_sample_tip_off_screen(direction, x_centre):
test_arr = np.array(
[
Expand Down Expand Up @@ -147,9 +153,7 @@ def test_locate_sample_with_min_tip_height(
dtype=np.int32,
)

location = MxSampleDetect(
min_tip_height=min_tip_width, scan_direction=1
)._locate_sample(test_arr)
location = MxSampleDetect(min_tip_height=min_tip_width)._locate_sample(test_arr)

assert location.edge_top is not None
assert location.edge_bottom is not None
Expand Down

0 comments on commit f3a915a

Please sign in to comment.