Skip to content

Commit

Permalink
ARRUS-109: Updated Python remapping code for batch of sequences of RF…
Browse files Browse the repository at this point in the history
… frames. (#244)
  • Loading branch information
pjarosik authored Nov 10, 2021
1 parent 82268fd commit 58a6989
Show file tree
Hide file tree
Showing 19 changed files with 626 additions and 283 deletions.
5 changes: 5 additions & 0 deletions api/python/arrus/devices/us4r.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ class FrameChannelMapping:
:param frames: a mapping: (logical frame, logical channel) -> physical frame
:param channels: a mapping: (logical frame, logical channel) -> physical channel
:param us4oems: a mapping: (logical frame, logical channel) -> us4OEM number
:param frame_offsets: frame starting number for each us4OEM available in the system
:param batch_size: number of sequences in a single batch
"""
frames: np.ndarray
channels: np.ndarray
us4oems: np.ndarray
frame_offsets: np.ndarray
batch_size: int = 1


Expand Down
3 changes: 2 additions & 1 deletion api/python/arrus/kernels/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def process_simple_tx_rx_sequence(context):
rx = Rx(rx_aperture, sample_range, op.downsampling_factor,
padding=rx_padding)
txrx.append(TxRx(tx, rx, op.pri))
return TxRxSequence(txrx, tgc_curve=tgc_curve, sri=op.sri)
return TxRxSequence(txrx, tgc_curve=tgc_curve, sri=op.sri,
n_repeats=op.n_repeats)


def get_aperture_center(tx_aperture_center_element, probe):
Expand Down
3 changes: 3 additions & 0 deletions api/python/arrus/ops/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class SimpleTxRxSequence:
:param sri: sequence repetition interval - the time between consecutive RF \
frames. When None, the time between consecutive RF frames is determined \
by the total pri only. [s]
:param n_repeats: size of a single batch -- how many times this sequence should be \
repeated before data is transferred to computer (integer)
"""
pulse: arrus.ops.us4r.Pulse
rx_sample_range: tuple
Expand All @@ -78,6 +80,7 @@ class SimpleTxRxSequence:
tgc_start: float = None
tgc_slope: float = None
tgc_curve: list = None
n_repeats: int = 1

def __post_init__(self):
# Validation
Expand Down
17 changes: 10 additions & 7 deletions api/python/arrus/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,14 @@ def upload(self, scheme: arrus.ops.us4r.Scheme):
###
# -- Constant metadata
# --- FCM
fcm_frame, fcm_channel = arrus.utils.core.convert_fcm_to_np_arrays(fcm)
fcm_us4oems, fcm_frame, fcm_channel, frame_offsets = \
arrus.utils.core.convert_fcm_to_np_arrays(fcm, us_device.n_us4oems)
fcm = arrus.devices.us4r.FrameChannelMapping(
frames=fcm_frame, channels=fcm_channel, batch_size=1)
us4oems=fcm_us4oems,
frames=fcm_frame,
channels=fcm_channel,
frame_offsets=frame_offsets,
batch_size=batch_size)

# --- Frame acquisition context
fac = self._create_frame_acquisition_context(seq, raw_seq, us_device_dto, medium)
Expand All @@ -121,9 +126,9 @@ def upload(self, scheme: arrus.ops.us4r.Scheme):
"Currently only a sequence with constant number of samples "
"can be accepted.")
n_samples = next(iter(n_samples))
input_shape = self._get_physical_frame_shape(fcm, n_samples, rx_batch_size=batch_size)

buffer = arrus.framework.DataBuffer(buffer_handle)
input_shape = buffer.elements[0].data.shape

const_metadata = arrus.metadata.ConstMetadata(
context=fac, data_desc=echo_data_description,
Expand All @@ -137,14 +142,13 @@ def upload(self, scheme: arrus.ops.us4r.Scheme):
raise ValueError("Currently only arrus.utils.imaging.Pipeline "
"processing is supported only.")
import cupy as cp

out_metadata = processing.prepare(const_metadata)
self.gpu_buffer = arrus.utils.imaging.Buffer(n_elements=4,
self.gpu_buffer = arrus.utils.imaging.Buffer(n_elements=2,
shape=const_metadata.input_shape,
dtype=const_metadata.dtype,
math_pkg=cp,
type="locked")
self.out_buffer = [arrus.utils.imaging.Buffer(n_elements=4,
self.out_buffer = [arrus.utils.imaging.Buffer(n_elements=2,
shape=m.input_shape,
dtype=m.dtype, math_pkg=np,
type="locked")
Expand All @@ -167,7 +171,6 @@ def buffer_callback(elements):
print(f"Exception: {type(e)}")
except:
print("Unknown exception")

pipeline_wrapper = arrus.utils.imaging.PipelineRunner(
buffer, self.gpu_buffer, self.out_buffer, processing,
buffer_callback)
Expand Down
16 changes: 12 additions & 4 deletions api/python/arrus/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,35 @@ def convert_to_core_sequence(seq):
return core_seq


def convert_fcm_to_np_arrays(fcm, n_us4oems):
    """
    Converts frame channel mapping to a tuple of numpy arrays.

    :param fcm: arrus.core.FrameChannelMapping
    :param n_us4oems: number of us4OEMs for which the first frame number \
        should be read from the mapping
    :return: a tuple of numpy arrays: fcm_us4oem, fcm_frame, fcm_channel \
        (each of shape (number of logical frames, number of logical \
        channels)) and frame_offsets (one value per us4OEM)
    """
    n_frames = fcm.getNumberOfLogicalFrames()
    n_channels = fcm.getNumberOfLogicalChannels()
    shape = (n_frames, n_channels)
    fcm_us4oem = np.zeros(shape, dtype=np.uint8)
    fcm_frame = np.zeros(shape, dtype=np.int16)
    fcm_channel = np.zeros(shape, dtype=np.int8)
    for frame in range(n_frames):
        for channel in range(n_channels):
            frame_channel = fcm.getLogical(frame, channel)
            fcm_us4oem[frame, channel] = frame_channel.getUs4oem()
            fcm_frame[frame, channel] = frame_channel.getFrame()
            fcm_channel[frame, channel] = frame_channel.getChannel()
    # The first frame number of each us4OEM, indexed by the us4OEM number.
    frame_offsets = np.array(
        [fcm.getFirstFrame(i) for i in range(n_us4oems)],
        dtype=np.uint32)
    return fcm_us4oem, fcm_frame, fcm_channel, frame_offsets


def convert_to_py_probe_model(core_model):
Expand Down
187 changes: 164 additions & 23 deletions api/python/arrus/utils/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import cupy
# Compare the major version numerically: a plain string comparison is
# lexicographic, so e.g. "10.0.0" < "9.0.0" evaluates to True and newer
# cupy releases would be rejected.
if int(cupy.__version__.split(".")[0]) < 9:
    raise Exception(f"The version of cupy module is too low. "
                    f"Use version ''9.0.0'' or higher.")


def get_extent(x_grid, z_grid):
Expand Down Expand Up @@ -272,7 +272,8 @@ def process(self, data):
outputs.appendleft(output)
else:
data = step.process(data)
outputs.appendleft(data)
if not self._is_last_endpoint:
outputs.appendleft(data)
return outputs

def __initialize(self, const_metadata):
Expand Down Expand Up @@ -303,6 +304,9 @@ def prepare(self, const_metadata):
self.__initialize(const_metadata)
if not isinstance(self.steps[-1], Pipeline):
metadatas.appendleft(current_metadata)
self._is_last_endpoint = False
else:
self._is_last_endpoint = True
return metadatas

def set_placement(self, device):
Expand Down Expand Up @@ -1189,6 +1193,131 @@ def _put_ignore_full(self, data):
return data


class SelectSequenceRaw(Operation):
    """
    Selects a single sequence from a batch of raw RF data, i.e. from
    the data laid out in the physical order (rows x 32 RX channels).

    For each us4OEM referenced by the frame channel mapping, the rows of
    the selected sequence are copied into a contiguous output array.
    The returned metadata has the sequence ``n_repeats`` set to the number
    of selected sequences (1), and the frame channel mapping updated:
    ``batch_size=1`` and frame offsets matching the output array layout.

    :param sequence: the sequence to select: an int or an iterable with \
        exactly one int
    """

    def __init__(self, sequence):
        # Accept a scalar as well: prepare() needs len() and indexing.
        if isinstance(sequence, Iterable):
            sequence = list(sequence)
        else:
            sequence = [sequence]
        if len(sequence) != 1:
            raise ValueError("Only a single sequence can be selected")
        self.sequence = sequence
        self.output = None
        self.num_pkg = None
        self.positions = None

    def set_pkgs(self, num_pkg, **kwargs):
        self.num_pkg = num_pkg

    def prepare(self, const_metadata):
        """
        Computes the source/destination row ranges and allocates the output.

        :param const_metadata: metadata describing the input data
        :return: metadata describing the output data
        :raises arrus.exceptions.IllegalArgumentError: when tx/rxs of the \
            raw sequence acquire different number of samples
        """
        context = const_metadata.context
        seq = context.sequence
        raw_seq = context.raw_sequence
        n_seq = len(self.sequence)
        # NOTE: only a single sequence is supported.
        selected_sequence = int(self.sequence[0])

        n_samples_set = {op.rx.get_n_samples() for op in raw_seq.ops}
        if len(n_samples_set) > 1:
            raise arrus.exceptions.IllegalArgumentError(
                f"Each tx/rx in the sequence should acquire the same number of "
                f"samples (actual: {n_samples_set})")
        n_samples = next(iter(n_samples_set))

        data_desc = const_metadata.data_description
        fcm = data_desc.custom["frame_channel_mapping"]
        fcm_us4oems = fcm.us4oems
        fcm_frames = fcm.frames
        # Process us4OEMs in ascending order, so the output layout is
        # deterministic (iteration order of a set is not guaranteed).
        us4oems = sorted(set(fcm_us4oems.flatten().tolist()))

        # For each us4OEM, compute tuples: (src_start, dst_start, src_end, dst_end)
        # Where each value is the number of rows (we assume 32 columns, i.e. RX channels)
        self.positions = []
        # New frame offsets are indexed by us4OEM number, the same way the
        # input frame offsets are read below; entries of us4OEMs that are
        # not referenced by the mapping stay equal 0.
        frame_offsets = np.zeros(len(fcm.frame_offsets), dtype=np.uint32)
        dst_start = 0
        dst_end = 0
        current_frame = 0  # Current physical frame.
        for us4oem in us4oems:
            # Plain Python ints: avoids fixed-width wrap-around in the row
            # arithmetic below.
            n_frames = int(self.num_pkg.max(fcm_frames[fcm_us4oems == us4oem])) + 1
            us4oem_offset = int(fcm.frame_offsets[us4oem])
            src_start = (us4oem_offset*n_samples
                         + selected_sequence*n_frames*n_samples)
            src_end = src_start + n_frames*n_samples
            dst_end = dst_start + n_frames*n_samples
            self.positions.append((src_start, dst_start, src_end, dst_end))
            frame_offsets[us4oem] = current_frame
            current_frame += n_frames
            dst_start = dst_end

        output_shape = (dst_end, 32)
        self.output = self.num_pkg.zeros(output_shape, dtype=np.int16)

        # Update const metadata
        new_seq = dataclasses.replace(seq, n_repeats=n_seq)
        new_raw_seq = dataclasses.replace(raw_seq, n_repeats=n_seq)
        new_context = arrus.metadata.FrameAcquisitionContext(
            device=context.device, sequence=new_seq,
            raw_sequence=new_raw_seq, medium=context.medium,
            custom_data=context.custom_data)

        # Update FCM (change the batch_size)
        new_data_desc_custom = data_desc.custom.copy()
        new_fcm = dataclasses.replace(fcm, batch_size=1,
                                      frame_offsets=frame_offsets)
        new_data_desc_custom["frame_channel_mapping"] = new_fcm
        new_data_desc = dataclasses.replace(data_desc, custom=new_data_desc_custom)

        return const_metadata.copy(input_shape=output_shape,
                                   context=new_context,
                                   data_desc=new_data_desc)

    def process(self, data):
        """
        Copies rows of the selected sequence to the preallocated output.

        :param data: input array, rows are indexed along the first axis
        :return: the output array (the same array object on each call)
        """
        for src_start, dst_start, src_end, dst_end in self.positions:
            self.output[dst_start:dst_end, :] = data[src_start:src_end, :]
        return self.output


class SelectSequence(Operation):
    """
    Selects sequences for a given batch for further processing.
    This operator modifies input context so the appropriate
    number of sequences is properly set.

    :param sequence: sequences to select: an int or an iterable of ints
    """

    def __init__(self, sequence):
        if not isinstance(sequence, Iterable):
            # Wrap into an array
            sequence = [sequence]
        # Always keep a list: indexing an array with a tuple selects
        # a single element along multiple axes, instead of selecting
        # multiple sequences along the first axis.
        self.sequence = list(sequence)

    def set_pkgs(self, **kwargs):
        pass

    def prepare(self, const_metadata):
        """
        Determines the output shape and updates the acquisition context.

        :param const_metadata: metadata describing the input data
        :return: metadata with the first axis of the input shape replaced \
            by the number of selected sequences, and with n_repeats of \
            the sequences set to that number
        """
        context = const_metadata.context
        n_seq = len(self.sequence)

        output_shape = (n_seq, ) + const_metadata.input_shape[1:]
        new_seq = dataclasses.replace(context.sequence, n_repeats=n_seq)
        new_raw_seq = dataclasses.replace(context.raw_sequence,
                                          n_repeats=n_seq)
        new_context = arrus.metadata.FrameAcquisitionContext(
            device=context.device, sequence=new_seq,
            raw_sequence=new_raw_seq, medium=context.medium,
            custom_data=context.custom_data)
        return const_metadata.copy(input_shape=output_shape,
                                   context=new_context)

    def process(self, data):
        """
        Returns the selected sequences (indexed along the first axis).
        """
        return data[self.sequence]


class SelectFrames(Operation):
"""
Selects frames for a given sequence for further processing.
Expand Down Expand Up @@ -1252,11 +1381,15 @@ def process(self, data):

def _limit_params(self, value, frames):
if value is not None and hasattr(value, "__len__") and len(value) > 1:
return value[frames]
return np.array(value)[frames]
else:
return value


# Alias: SelectFrame is a second name bound to the SelectFrames class.
SelectFrame = SelectFrames


class Squeeze(Operation):
"""
Squeezes input array (removes axes = 1).
Expand Down Expand Up @@ -1739,44 +1872,52 @@ def prepare(self, const_metadata: arrus.metadata.ConstMetadata):
n_frames, n_channels = fcm.frames.shape
n_samples_set = {op.rx.get_n_samples()
for op in const_metadata.context.raw_sequence.ops}

# get (unique) number of samples in a frame
if len(n_samples_set) > 1:
raise arrus.exceptions.IllegalArgumentError(
f"Each tx/rx in the sequence should acquire the same number of "
f"samples (actual: {n_samples_set})")
n_samples = next(iter(n_samples_set))
self.output_shape = (n_frames, n_samples, n_channels)
batch_size = fcm.batch_size
self.output_shape = (batch_size, n_frames, n_samples, n_channels)
self._output_buffer = xp.zeros(shape=self.output_shape, dtype=xp.int16)

n_samples_raw, n_channels_raw = const_metadata.input_shape
self._input_shape = (n_samples_raw//n_samples, n_samples,
n_channels_raw)
self.batch_size = fcm.batch_size

if xp == np:
# CPU
self._transfers = __group_transfers(fcm)
def cpu_remap_fn(data):
__remap(self._output_buffer,
data.reshape(self._input_shape),
transfers=self._transfers)
self._remap_fn = cpu_remap_fn
raise ValueError(f"'{type(self).__name__}' is not implemented for CPU")
else:
# GPU
import cupy as cp
from arrus.utils.us4r_remap_gpu import get_default_grid_block_size, run_remap
self._fcm_frames = cp.asarray(fcm.frames)
self._fcm_channels = cp.asarray(fcm.channels)
self.grid_size, self.block_size = get_default_grid_block_size(self._fcm_frames, n_samples)

self._fcm_us4oems = cp.asarray(fcm.us4oems)
frame_offsets = fcm.frame_offsets
# TODO constant memory
self._frame_offsets = cp.asarray(frame_offsets)
# For each us4OEM, get number of physical frames this us4OEM gathers.
# Note: this is the max number of us4OEM IN USE.
n_us4oems = cp.max(self._fcm_us4oems).get()+1
n_frames_us4oems = []
for us4oem in range(n_us4oems):
n_frames_us4oem = cp.max(self._fcm_frames[self._fcm_us4oems == us4oem])
n_frames_us4oems.append(n_frames_us4oem)

# TODO constant memory
self._n_frames_us4oems = cp.asarray(n_frames_us4oems, dtype=cp.uint32)+1
self.grid_size, self.block_size = get_default_grid_block_size(
self._fcm_frames, n_samples,
batch_size
)
def gpu_remap_fn(data):
run_remap(
self.grid_size, self.block_size,
run_remap(self.grid_size, self.block_size,
[self._output_buffer, data,
self._fcm_frames, self._fcm_channels,
n_frames, n_samples, n_channels])
self._fcm_frames, self._fcm_channels, self._fcm_us4oems,
self._frame_offsets,
self._n_frames_us4oems,
batch_size, n_frames, n_samples, n_channels])

self._remap_fn = gpu_remap_fn

return const_metadata.copy(input_shape=self.output_shape)

def process(self, data):
Expand Down
Loading

0 comments on commit 58a6989

Please sign in to comment.