Skip to content

Commit

Permalink
Set max instances for top down models (#1070)
Browse files Browse the repository at this point in the history
* Add optional unragging arg to model export

* Add option to set max instances for multi-instance models

* Fix test
  • Loading branch information
sheridana authored Dec 8, 2022
1 parent 5956782 commit eac2e2b
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 6 deletions.
126 changes: 120 additions & 6 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

"""Export a trained SLEAP model as a frozen graph. Initializes model,
Expand All @@ -470,10 +471,20 @@ def export_model(
sleap.nn.data.utils.describe_tensors as an example)
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs
max_instances: If set, determines the max number of instances that a
multi-instance model returns. This is enforced during centroid
cropping and therefore only compatible with TopDown models.
"""

self._initialize_inference_model()
predictor_name = type(self).__name__

if max_instances is not None:
if "TopDown" in predictor_name:
print(f"\n max instances set, limiting instances to {max_instances} \n")
self.inference_model.centroid_crop.max_instances = max_instances
else:
raise Exception(f"{predictor_name} does not support max instance limit")

first_inference_layer = self.inference_model.layers[0]
keras_model_shape = first_inference_layer.keras_model.input.shape
Expand Down Expand Up @@ -1469,10 +1480,17 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)

self.confmap_config.save_json(os.path.join(save_path, "confmap_config.json"))
Expand Down Expand Up @@ -1523,6 +1541,8 @@ class CentroidCrop(InferenceLayer):
the predicted peaks. This is true by default since crops are used
for finding instance peaks in a top down model. If using a centroid
only inference model, this should be set to `False`.
max_instances: If set, determines the max number of instances that a
multi-instance model returns.
"""

def __init__(
Expand All @@ -1539,6 +1559,7 @@ def __init__(
confmaps_ind: Optional[int] = None,
offsets_ind: Optional[int] = None,
return_crops: bool = True,
max_instances: Optional[int] = None,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -1576,6 +1597,7 @@ def __init__(
self.integral_patch_size = integral_patch_size
self.return_confmaps = return_confmaps
self.return_crops = return_crops
self.max_instances = max_instances

@tf.function
def call(self, inputs):
Expand Down Expand Up @@ -1669,9 +1691,67 @@ def call(self, inputs):
# Store crop offsets.
crop_offsets = centroid_points - (self.crop_size / 2)

samples = tf.shape(imgs)[0]

n_peaks = tf.shape(centroid_points)[0]

if n_peaks > 0:

if self.max_instances is not None:

centroid_points = tf.RaggedTensor.from_value_rowids(
centroid_points, crop_sample_inds, nrows=samples
)
centroid_vals = tf.RaggedTensor.from_value_rowids(
centroid_vals, crop_sample_inds, nrows=samples
)

_centroid_vals = tf.TensorArray(
size=samples,
dtype=tf.float32,
infer_shape=False,
element_shape=[None],
)

_centroid_points = tf.TensorArray(
size=samples,
dtype=tf.float32,
infer_shape=False,
element_shape=[None, 2],
)

_row_ids = tf.TensorArray(
size=samples,
dtype=tf.int32,
infer_shape=False,
element_shape=[None],
)

for sample in range(samples):

top_points = tf.math.top_k(
centroid_vals[sample], k=self.max_instances
)
top_inds = top_points.indices

_centroid_vals = _centroid_vals.write(
sample, tf.gather(centroid_vals[sample], top_inds)
)

_centroid_points = _centroid_points.write(
sample, tf.gather(centroid_points[sample], top_inds)
)

_row_ids = _row_ids.write(sample, tf.fill([len(top_inds)], sample))

centroid_vals = _centroid_vals.concat()
centroid_points = _centroid_points.concat()
crop_sample_inds = _row_ids.concat()

n_peaks = tf.shape(crop_sample_inds)[0]

crop_offsets = centroid_points - (self.crop_size / 2)

# Crop instances around centroids.
bboxes = sleap.nn.data.instance_cropping.make_centered_bboxes(
centroid_points, self.crop_size, self.crop_size
Expand All @@ -1684,6 +1764,7 @@ def call(self, inputs):
crops = tf.reshape(
crops, [n_peaks, self.crop_size, self.crop_size, full_imgs.shape[3]]
)

else:
# No peaks found, so just create a placeholder stack.
crops = tf.zeros(
Expand All @@ -1692,7 +1773,6 @@ def call(self, inputs):
)

# Group crops by sample (samples, ?, ...).
samples = tf.shape(imgs)[0]
centroids = tf.RaggedTensor.from_value_rowids(
centroid_points, crop_sample_inds, nrows=samples
)
Expand Down Expand Up @@ -2390,10 +2470,17 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)

if self.confmap_config is not None:
Expand Down Expand Up @@ -4086,10 +4173,17 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)

if self.confmap_config is not None:
Expand Down Expand Up @@ -4215,6 +4309,7 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):
"""High level export of a trained SLEAP model as a frozen graph.
Expand All @@ -4232,10 +4327,20 @@ def export_model(
sleap.nn.data.utils.describe_tensors as an example).
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs
max_instances: If set, determines the max number of instances that a
multi-instance model returns. This is enforced during centroid
cropping and therefore only compatible with TopDown models.
"""
predictor = load_model(model_path)

predictor.export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)


Expand Down Expand Up @@ -4273,6 +4378,15 @@ def export_cli():
"Defaults to True."
),
)
parser.add_argument(
"-m",
"--max_instances",
type=int,
help=(
"Limit maximum number of instances in multi-instance models"
"Defaults to None"
),
)

args, _ = parser.parse_known_args()
export_model(args.models, args.export_path, unrag_outputs=args.unrag)
Expand Down
42 changes: 42 additions & 0 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
predictor = TopDownPredictor.from_trained_models(
centroid_model_path=min_centroid_model_path
)

predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
Expand All @@ -603,13 +604,21 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
inds1, inds2 = sleap.nn.utils.match_points(points_gt, points_pr)
assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5)

# test max_instances (>2 will fail)
predictor.inference_model.centroid_crop.max_instances = 2
labels_pr = predictor.predict(min_labels)

assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 2


def test_topdown_predictor_centered_instance(
min_labels, min_centered_instance_model_path
):
predictor = TopDownPredictor.from_trained_models(
confmap_model_path=min_centered_instance_model_path
)

predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
Expand Down Expand Up @@ -859,6 +868,14 @@ def test_centroid_inference():
assert preds["centroids"].shape == (1, 3, 2)
assert preds["centroid_vals"].shape == (1, 3)

# test max instances (>3 will fail)
layer.max_instances = 3
out = layer(cms)

model = CentroidInferenceModel(layer)

preds = model.predict(cms)


def export_frozen_graph(model, preds, output_path):

Expand Down Expand Up @@ -1008,6 +1025,15 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm
unrag_outputs=False,
)

# max_instances should raise an exception for single instance
with pytest.raises(Exception):
export_model(
min_single_instance_robot_model_path,
save_path=tmp_path.as_posix(),
unrag_outputs=False,
max_instances=1,
)


def test_topdown_predictor_save(
min_centroid_model_path, min_centered_instance_model_path, tmp_path
Expand Down Expand Up @@ -1039,6 +1065,14 @@ def test_topdown_predictor_save(
unrag_outputs=False,
)

# test max instances
export_model(
[min_centroid_model_path, min_centered_instance_model_path],
save_path=tmp_path.as_posix(),
unrag_outputs=False,
max_instances=4,
)


def test_topdown_id_predictor_save(
min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
Expand Down Expand Up @@ -1070,6 +1104,14 @@ def test_topdown_id_predictor_save(
unrag_outputs=False,
)

# test max instances
export_model(
[min_centroid_model_path, min_topdown_multiclass_model_path],
save_path=tmp_path.as_posix(),
unrag_outputs=False,
max_instances=4,
)


@pytest.mark.parametrize(
"output_path,tracker_method", [("not_default", "flow"), (None, "simple")]
Expand Down

0 comments on commit eac2e2b

Please sign in to comment.