Inference CLI fixes (#459)
Fix single instance inference and RGB detection.
talmo authored Jan 18, 2021
1 parent 292446e commit cc057cb
Showing 21 changed files with 178 additions and 71 deletions.
4 changes: 2 additions & 2 deletions sleap/nn/data/confidence_maps.py
@@ -536,9 +536,9 @@ def generate_confmaps(example):
if self.with_offsets:
example["offsets"] = mask_offsets(
make_offsets(
example["center_instance"], xv, yv, stride=self.output_stride
example["points"], xv, yv, stride=self.output_stride
),
example["instance_confidence_maps"],
example["confidence_maps"],
self.offsets_threshold,
)
if self.flatten_offsets:
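For context on this fix: the single-instance generator reads the full-image "points" and "confidence_maps" keys rather than the instance-centered keys. Below is a minimal numpy sketch of the offset-refinement idea only; the function name and arguments are made up for illustration and do not reproduce the `make_offsets`/`mask_offsets` implementation.

```python
# Minimal sketch (illustration only, not the SLEAP implementation): grid cells
# whose confidence exceeds a threshold store the displacement from the cell
# position to the true point; everything else is zeroed out.
import numpy as np

def offset_map(point_xy, confmap, xv, yv, threshold=0.2):
    """Return (H, W, 2) offsets to `point_xy`, zeroed where confmap <= threshold."""
    dx = point_xy[0] - xv.reshape(1, -1)   # (1, W), broadcasts to (H, W)
    dy = point_xy[1] - yv.reshape(-1, 1)   # (H, 1), broadcasts to (H, W)
    offsets = np.stack(np.broadcast_arrays(dx, dy), axis=-1)
    mask = (confmap > threshold)[..., np.newaxis]
    return offsets * mask

# Example on a 4x4 grid with unit stride:
xv, yv = np.arange(4, dtype=np.float32), np.arange(4, dtype=np.float32)
confmap = np.zeros((4, 4), dtype=np.float32)
confmap[2, 1] = 1.0  # peak near the point below
print(offset_map(np.array([1.3, 2.4]), confmap, xv, yv)[2, 1])  # ~[0.3 0.4]
```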
5 changes: 5 additions & 0 deletions sleap/nn/data/pipelines.py
@@ -331,11 +331,14 @@ class SingleInstanceConfmapsPipeline:
optimization_config: Optimization-related configuration.
single_instance_confmap_head: Instantiated head describing the output confidence
maps tensor.
offsets_head: Optional head describing the offset refinement maps.
"""

data_config: DataConfig
optimization_config: OptimizationConfig
single_instance_confmap_head: SingleInstanceConfmapsHead
offsets_head: Optional[OffsetRefinementHead] = None


def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
"""Create base pipeline with input data only.
@@ -394,6 +397,8 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
pipeline += SingleInstanceConfidenceMapGenerator(
sigma=self.single_instance_confmap_head.sigma,
output_stride=self.single_instance_confmap_head.output_stride,
with_offsets=self.offsets_head is not None,
offsets_threshold=self.offsets_head.sigma_threshold if self.offsets_head is not None else 1.0
)

if len(data_provider) >= self.optimization_config.batch_size:
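A hedged sketch of how the new field could be wired up: the attribute names match the attrs definition above, but the surrounding objects (`cfg`, `confmap_head`, `offset_head`, `data_provider`) are placeholders for illustration.

```python
# Sketch only: constructing the pipeline with the optional offsets head.
pipeline_builder = SingleInstanceConfmapsPipeline(
    data_config=cfg.data,
    optimization_config=cfg.optimization,
    single_instance_confmap_head=confmap_head,
    offsets_head=offset_head,  # pass None to train without offset refinement
)
training_pipeline = pipeline_builder.make_training_pipeline(data_provider)
```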
2 changes: 1 addition & 1 deletion sleap/nn/data/providers.py
@@ -285,7 +285,7 @@ def make_dataset(self) -> tf.data.Dataset:
grid in order to properly map points to image coordinates.
"""
# Grab an image to test for the dtype.
test_image = tf.convert_to_tensor(self.video.get_frame(0))
test_image = tf.convert_to_tensor(self.video.get_frame(self.video.last_frame_idx))
image_dtype = test_image.dtype

def py_fetch_frame(ind):
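The dtype probe now uses the video's last frame index instead of frame 0. As a rough illustration of the general pattern (not the SLEAP implementation), a generator-backed dataset probes one frame so it can declare a fixed output dtype up front; `video` and `frame_inds` below are placeholder objects.

```python
import tensorflow as tf

def make_frame_dataset(video, frame_inds):
    # Probe a frame so the generator's dtype is known when building the graph.
    test_image = tf.convert_to_tensor(video.get_frame(frame_inds[-1]))

    def py_fetch_frame(ind):
        return video.get_frame(int(ind.numpy()))

    def fetch(ind):
        return tf.py_function(py_fetch_frame, [ind], Tout=test_image.dtype)

    return tf.data.Dataset.from_tensor_slices(frame_inds).map(fetch)
```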
25 changes: 18 additions & 7 deletions sleap/nn/inference.py
@@ -764,6 +764,14 @@ def call(self, data):
threshold=self.peak_threshold,
)

# Adjust for stride and scale.
peaks = peaks * self.output_stride
if self.input_scale != 1.0:
# Note: We add 0.5 here to offset TensorFlow's weird image resizing. This
# may not always(?) be the most correct approach.
# See: https://github.com/tensorflow/tensorflow/issues/6720
peaks = (peaks / self.input_scale) + 0.5

out = {"peaks": peaks, "peak_vals": peak_vals}
if self.return_confmaps:
out["confmaps"] = cms
@@ -961,13 +969,16 @@ def _predict_generator(
# Run inference on current batch.
preds = self.inference_model.predict(ex)

# Convert to numpy arrays if not already.
if isinstance(preds["video_ind"], tf.Tensor):
preds["video_ind"] = preds["video_ind"].numpy().flatten()
if isinstance(preds["frame_ind"], tf.Tensor):
preds["frame_ind"] = preds["frame_ind"].numpy().flatten()
ex["peaks"] = preds["peaks"]
ex["peak_vals"] = preds["peak_vals"]

yield preds
# Convert to numpy arrays if not already.
if isinstance(ex["video_ind"], tf.Tensor):
ex["video_ind"] = ex["video_ind"].numpy().flatten()
if isinstance(ex["frame_ind"], tf.Tensor):
ex["frame_ind"] = ex["frame_ind"].numpy().flatten()

yield ex

def _make_labeled_frames_from_generator(
self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider
@@ -2775,7 +2786,7 @@ def get_relevant_args(key):
)
elif "single_instance" in trained_model_paths:
predictor = SingleInstancePredictor.from_trained_models(
confmap_model_path=trained_model_paths["single_instance"],
trained_model_paths["single_instance"],
**get_relevant_args("single"),
**kwargs,
)
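After this change the generator yields the full example dict augmented with the predictions, so the video and frame indices travel with the peaks to the frame-construction step. A condensed sketch of that pattern (variable names are assumptions, not the exact SLEAP code):

```python
def predict_generator(inference_model, examples):
    for ex in examples:
        preds = inference_model.predict(ex)
        # Attach model outputs to the input example and yield it whole.
        ex["peaks"] = preds["peaks"]
        ex["peak_vals"] = preds["peak_vals"]
        for key in ("video_ind", "frame_ind"):
            if hasattr(ex[key], "numpy"):  # convert tf.Tensor to numpy
                ex[key] = ex[key].numpy().flatten()
        yield ex
```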
11 changes: 0 additions & 11 deletions tests/data/models/minimal_instance.UNet.centroid2/training_log.csv

This file was deleted.

@@ -26,36 +26,36 @@
"leap": null,
"unet": {
"stem_stride": null,
"max_stride": 8,
"output_stride": 2,
"filters": 16,
"filters_rate": 1.5,
"max_stride": 4,
"output_stride": 4,
"filters": 8,
"filters_rate": 2.0,
"middle_block": true,
"up_interpolate": false,
"up_interpolate": true,
"stacks": 1
},
"hourglass": null,
"resnet": null,
"pretrained_encoder": null
},
"heads": {
"single_instance": null,
"centroid": {
"anchor_part": null,
"single_instance": {
"part_names": null,
"sigma": 5.0,
"output_stride": 2,
"offset_refinement": true
"output_stride": 4,
"offset_refinement": false
},
"centroid": null,
"centered_instance": null,
"multi_instance": null
}
},
"optimization": {
"preload_data": false,
"preload_data": true,
"augmentation_config": {
"rotate": false,
"rotation_min_angle": -180,
"rotation_max_angle": 180,
"rotation_min_angle": -180.0,
"rotation_max_angle": 180.0,
"translate": false,
"translate_min": -5,
"translate_max": 5,
@@ -83,12 +83,12 @@
"prefetch": true,
"batch_size": 4,
"batches_per_epoch": null,
"min_batches_per_epoch": 100,
"min_batches_per_epoch": 200,
"val_batches_per_epoch": null,
"min_val_batches_per_epoch": 1,
"epochs": 10,
"min_val_batches_per_epoch": 10,
"epochs": 100,
"optimizer": "adam",
"initial_learning_rate": 0.0001,
"initial_learning_rate": 0.001,
"learning_rate_schedule": {
"reduce_on_plateau": true,
"reduction_factor": 0.5,
@@ -112,11 +112,13 @@
},
"outputs": {
"save_outputs": true,
"run_name": "minimal_instance.UNet.centroid",
"run_name": "minimal_robot.UNet.single_instance",
"run_name_prefix": "",
"run_name_suffix": null,
"runs_folder": "models",
"tags": [],
"run_name_suffix": "",
"runs_folder": "",
"tags": [
""
],
"save_visualizations": false,
"log_to_csv": true,
"checkpointing": {
@@ -134,10 +136,10 @@
"visualizations": true
},
"zmq": {
"subscribe_to_controller": false,
"subscribe_to_controller": true,
"controller_address": "tcp://127.0.0.1:9000",
"controller_polling_timeout": 10,
"publish_updates": false,
"publish_updates": true,
"publish_address": "tcp://127.0.0.1:9001"
}
}
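For reference, a saved training_config.json like the one above can be inspected with the standard library; the path below is a placeholder, not taken from this commit.

```python
import json

# Placeholder path; substitute the actual model run folder.
with open("models/minimal_robot.UNet.single_instance/training_config.json") as f:
    cfg = json.load(f)

print(cfg["model"]["heads"]["single_instance"]["output_stride"])  # e.g. 4
print(cfg["optimization"]["initial_learning_rate"])               # e.g. 0.001
```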
@@ -1,7 +1,7 @@
{
"data": {
"labels": {
"training_labels": null,
"training_labels": "tests/data/slp_hdf5/small_robot_minimal.slp",
"validation_labels": null,
"validation_fraction": 0.1,
"test_labels": null,
@@ -70,7 +70,7 @@
"ensure_grayscale": false,
"imagenet_mode": null,
"input_scaling": 0.5,
"pad_to_stride": 8
"pad_to_stride": 4
},
"instance_cropping": {
"center_on_part": null,
@@ -83,36 +83,39 @@
"leap": null,
"unet": {
"stem_stride": null,
"max_stride": 8,
"output_stride": 2,
"filters": 16,
"filters_rate": 1.5,
"max_stride": 4,
"output_stride": 4,
"filters": 8,
"filters_rate": 2.0,
"middle_block": true,
"up_interpolate": false,
"up_interpolate": true,
"stacks": 1
},
"hourglass": null,
"resnet": null,
"pretrained_encoder": null
},
"heads": {
"single_instance": null,
"centroid": {
"anchor_part": null,
"single_instance": {
"part_names": [
"A",
"B"
],
"sigma": 5.0,
"output_stride": 2,
"offset_refinement": true
"output_stride": 4,
"offset_refinement": false
},
"centroid": null,
"centered_instance": null,
"multi_instance": null
}
},
"optimization": {
"preload_data": false,
"preload_data": true,
"augmentation_config": {
"rotate": false,
"rotation_min_angle": -180,
"rotation_max_angle": 180,
"rotation_min_angle": -180.0,
"rotation_max_angle": 180.0,
"translate": false,
"translate_min": -5,
"translate_max": 5,
@@ -139,13 +142,13 @@
"shuffle_buffer_size": 128,
"prefetch": true,
"batch_size": 4,
"batches_per_epoch": 100,
"min_batches_per_epoch": 100,
"val_batches_per_epoch": 1,
"min_val_batches_per_epoch": 1,
"epochs": 10,
"batches_per_epoch": 200,
"min_batches_per_epoch": 200,
"val_batches_per_epoch": 10,
"min_val_batches_per_epoch": 10,
"epochs": 100,
"optimizer": "adam",
"initial_learning_rate": 0.0001,
"initial_learning_rate": 0.001,
"learning_rate_schedule": {
"reduce_on_plateau": true,
"reduction_factor": 0.5,
@@ -169,11 +172,13 @@
},
"outputs": {
"save_outputs": true,
"run_name": "minimal_instance.UNet.centroid",
"run_name": "minimal_robot.UNet.single_instance",
"run_name_prefix": "",
"run_name_suffix": "",
"runs_folder": "models",
"tags": [],
"runs_folder": "",
"tags": [
""
],
"save_visualizations": false,
"log_to_csv": true,
"checkpointing": {
@@ -191,10 +196,10 @@
"visualizations": true
},
"zmq": {
"subscribe_to_controller": false,
"subscribe_to_controller": true,
"controller_address": "tcp://127.0.0.1:9000",
"controller_polling_timeout": 10,
"publish_updates": false,
"publish_updates": true,
"publish_address": "tcp://127.0.0.1:9001"
}
}
@@ -0,0 +1,13 @@
epoch,loss,lr,val_loss
0,0.0060239629819989204,0.001,0.0020105133298784494
1,0.0006304828566499054,0.001,0.001671794569119811
2,0.000706226215697825,0.001,0.0021381149999797344
3,0.0002680201141629368,0.001,0.0015851699281483889
4,0.00021906329493504018,0.001,0.0016144101973623037
5,0.00013519266212824732,0.001,0.0015455052489414811
6,9.628612315282226e-05,0.001,0.0015609721885994077
7,0.00012194672308396548,0.001,0.0016072376165539026
8,0.00011606428597588092,0.001,0.0015309698646888137
9,7.172523328335956e-05,0.001,0.0016054341103881598
10,6.752563785994425e-05,0.001,0.0017311170231550932
11,6.967845547478646e-05,0.001,0.0016062406357377768
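The new training log can be summarized with the standard library; the file path below is a placeholder.

```python
import csv

with open("training_log.csv") as f:  # placeholder path
    rows = list(csv.DictReader(f))

best = min(rows, key=lambda r: float(r["val_loss"]))
print(f"best val_loss {float(best['val_loss']):.6f} at epoch {best['epoch']}")
```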
5 changes: 5 additions & 0 deletions tests/fixtures/datasets.py
@@ -18,6 +18,7 @@
TEST_JSON_MIN_LABELS = "tests/data/json_format_v2/minimal_instance.json"
TEST_SLP_MIN_LABELS = "tests/data/slp_hdf5/minimal_instance.slp"
TEST_MAT_LABELS = "tests/data/mat/labels.mat"
TEST_SLP_MIN_LABELS_ROBOT = "tests/data/slp_hdf5/small_robot_minimal.slp"


@pytest.fixture
@@ -40,6 +41,10 @@ def min_labels_slp():
return Labels.load_file(TEST_SLP_MIN_LABELS)


@pytest.fixture
def min_labels_robot():
return Labels.load_file(TEST_SLP_MIN_LABELS_ROBOT)

@pytest.fixture
def mat_labels():
return Labels.load_leap_matlab(TEST_MAT_LABELS, gui=False)
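A hypothetical test using the new fixture (not part of this commit); the assertion about the video count is an assumption about the small robot dataset.

```python
def test_min_labels_robot(min_labels_robot):
    # Hypothetical checks: the fixture loads a non-empty Labels object.
    assert len(min_labels_robot.labeled_frames) > 0
    assert len(min_labels_robot.videos) >= 1  # assumed: at least one video
```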
4 changes: 4 additions & 0 deletions tests/fixtures/models.py
@@ -15,3 +15,7 @@ def min_centered_instance_model_path():
@pytest.fixture
def min_bottomup_model_path():
return "tests/data/models/minimal_instance.UNet.bottomup"

@pytest.fixture
def min_single_instance_robot_model_path():
return "tests/data/models/minimal_robot.UNet.single_instance"