diff --git a/omnigibson/sensors/vision_sensor.py b/omnigibson/sensors/vision_sensor.py index 301a106b5..1ebdc2729 100644 --- a/omnigibson/sensors/vision_sensor.py +++ b/omnigibson/sensors/vision_sensor.py @@ -329,8 +329,6 @@ def _remap_modality(self, modality, obs, info, raw_obs): obs[modality], info[modality] = self._remap_instance_segmentation( obs[modality], id_to_labels, - obs["seg_semantic"], - info["seg_semantic"], id=(modality == "seg_instance_id"), ) elif "bbox" in modality: @@ -387,7 +385,7 @@ def _remap_semantic_segmentation(self, img, id_to_labels): return VisionSensor.SEMANTIC_REMAPPER.remap(replicator_mapping, semantic_class_id_to_name(), img, image_keys) - def _remap_instance_segmentation(self, img, id_to_labels, semantic_img, semantic_labels, id=False): + def _remap_instance_segmentation(self, img, id_to_labels, id=False): """ Remap the instance segmentation image to our own instance IDs. Also, correct the id_to_labels input with our new labels and return it. @@ -395,8 +393,6 @@ def _remap_instance_segmentation(self, img, id_to_labels, semantic_img, semantic Args: img (th.tensor): Instance segmentation image to remap id_to_labels (dict): Dictionary of instance IDs to class labels - semantic_img (th.tensor): Semantic segmentation image to use for instance registry - semantic_labels (dict): Dictionary of semantic IDs to class labels id (bool): Whether to remap for instance ID segmentation Returns: th.tensor: Remapped instance segmentation image diff --git a/omnigibson/utils/vision_utils.py b/omnigibson/utils/vision_utils.py index 519a60c2c..303688fab 100644 --- a/omnigibson/utils/vision_utils.py +++ b/omnigibson/utils/vision_utils.py @@ -71,12 +71,14 @@ class Remapper: def __init__(self): self.key_array = th.empty(0, dtype=th.int32, device="cuda") # Initialize the key_array as empty self.known_ids = set() + self.unlabelled_ids = set() self.warning_printed = set() def clear(self): """Resets the key_array to empty.""" self.key_array = th.empty(0, dtype=th.int32, device="cuda") self.known_ids = set() + self.unlabelled_ids = set() def remap(self, old_mapping, new_mapping, image, image_keys=None): """ @@ -109,6 +111,15 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None): # Copy the previous key array into the new key array self.key_array[: len(prev_key_array)] = prev_key_array + # Retrospectively inspect our cached ids against the old mapping and update the key array + updated_ids = set() + for unlabelled_id in self.unlabelled_ids: + if unlabelled_id in old_mapping and old_mapping[unlabelled_id] != "unlabelled": + # If an object was previously unlabelled but now has a label, we need to update the key array + updated_ids.add(unlabelled_id) + self.unlabelled_ids -= updated_ids + self.known_ids -= updated_ids + new_keys = old_mapping.keys() - self.known_ids if new_keys: self.known_ids.update(new_keys) @@ -118,6 +129,9 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None): new_key = next((k for k, v in new_mapping.items() if v == label), None) assert new_key is not None, f"Could not find a new key for label {label} in new_mapping!" self.key_array[key] = new_key + if label == "unlabelled": + # Some objects in the image might be unlabelled first but later get a valid label later, so we keep track of them + self.unlabelled_ids.add(key) # For all the values that exist in the image but not in old_mapping.keys(), we map them to whichever key in # new_mapping that equals to 'unlabelled'. This is needed because some values in the image don't necessarily