Skip to content

Commit

Permalink
Merge pull request #3019 from Rusteam/sam
Browse files Browse the repository at this point in the history
Add segment-anything models (SAM) to model zoo
  • Loading branch information
brimoor authored Jul 18, 2023
2 parents cd20a37 + cd51f71 commit 39eefc3
Show file tree
Hide file tree
Showing 5 changed files with 501 additions and 7 deletions.
92 changes: 90 additions & 2 deletions fiftyone/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def apply_model(
num_workers,
skip_failures,
filename_maker,
**kwargs,
)

if batch_size is not None:
Expand Down Expand Up @@ -352,8 +353,25 @@ def _apply_image_model_data_loader(
num_workers,
skip_failures,
filename_maker,
**kwargs,
):
samples = samples.select_fields()
if isinstance(model, SamplesMixin):
model.needs_fields = kwargs
extra_fields = model.needs_fields
else:
extra_fields = None

samples = samples.select_fields(extra_fields)
# we need to make sure classes are available in the samples
if extra_fields:
class_labels = samples.values(
f"{model.needs_fields[0]}.{model.field_type}.label"
)
class_labels = list(
set(one for labels in class_labels for one in labels)
)
samples.classes.update({model.needs_fields[0]: class_labels})

samples_loader = fou.iter_batches(samples, batch_size)
data_loader = _make_data_loader(
samples, model, batch_size, num_workers, skip_failures
Expand All @@ -365,7 +383,10 @@ def _apply_image_model_data_loader(
if isinstance(imgs, Exception):
raise imgs

labels_batch = model.predict_all(imgs)
if extra_fields is not None:
labels_batch = model.predict_all(imgs, sample_batch[0])
else:
labels_batch = model.predict_all(imgs)

for sample, labels in zip(sample_batch, labels_batch):
if filename_maker is not None:
Expand Down Expand Up @@ -2006,6 +2027,73 @@ def embed_prompts(self, args):
return np.stack([self.embed_prompt(arg) for arg in args])


class SamplesMixin(object):
    """Mixin for :class:`Model` classes that need extra sample fields for
    prediction.

    The required field is configured by assigning a dict of keyword arguments
    to :attr:`needs_fields`: the ``points_from`` key selects a keypoint field
    and the ``boxes_from`` key selects a detection field. When both are
    provided, the keypoint field takes precedence.

    A model on which nothing has been configured reports ``None`` for
    :attr:`needs_fields` and :attr:`field_type`.
    """

    @property
    def needs_fields(self):
        """A single-element list containing the name of the sample field
        required for this model to run, or ``None`` if no field is configured.
        """
        if self.keypoint_field is not None:
            return [self.keypoint_field]

        if self.detection_field is not None:
            return [self.detection_field]

        return None

    @needs_fields.setter
    def needs_fields(self, kwargs):
        """Find whether points or boxes are requested and set required fields
        as one of them if any.

        Args:
            kwargs: a dict of keyword arguments that may contain the
                ``points_from`` and/or ``boxes_from`` keys
        """
        self.keypoint_field = kwargs.get("points_from", None)
        self.detection_field = kwargs.get("boxes_from", None)

    @property
    def field_type(self):
        """The kind of labels stored in the required field: ``"keypoints"``,
        ``"detections"``, or ``None`` if no field is configured.
        """
        if self.keypoint_field is not None:
            return "keypoints"

        if self.detection_field is not None:
            return "detections"

        return None

    @property
    def keypoint_field(self):
        """The name of the keypoint field to use, or ``None``."""
        # Default to None so that models on which `needs_fields` was never
        # assigned behave as "no field required" rather than raising
        # AttributeError
        return getattr(self, "_keypoint_field", None)

    @keypoint_field.setter
    def keypoint_field(self, field_name):
        self._keypoint_field = field_name

    @property
    def detection_field(self):
        """The name of the detection field to use, or ``None``."""
        # See `keypoint_field` for why a default is provided
        return getattr(self, "_detection_field", None)

    @detection_field.setter
    def detection_field(self, field_name):
        self._detection_field = field_name

    def get_labels(self, samples):
        """Returns the value of the required field via
        ``samples.get_field()``.

        Args:
            samples: an object exposing ``has_field()`` and ``get_field()``

        Returns:
            the value returned by ``samples.get_field()``

        Raises:
            ValueError: if no field is configured or ``samples`` does not have
                the field
        """
        field_name = self._get_field_name(samples)
        return samples.get_field(field_name)

    def get_classes(self, samples):
        """Returns the classes registered for the required field via
        ``samples.dataset.get_classes()``.

        Args:
            samples: an object exposing ``has_field()`` and a ``dataset``
                attribute

        Returns:
            the value returned by ``samples.dataset.get_classes()``

        Raises:
            ValueError: if no field is configured or ``samples`` does not have
                the field
        """
        field_name = self._get_field_name(samples)
        return samples.dataset.get_classes(field_name)

    def _get_field_name(self, samples):
        """Returns the first required field name after validating that it is
        configured and present on ``samples``.
        """
        required = self.needs_fields
        if not required:
            raise ValueError("Set needs_fields before calling get_labels")

        field_name = required[0]
        if not samples.has_field(field_name):
            raise ValueError(f"Samples must have {field_name}.label field")

        return field_name


class TorchModelMixin(object):
"""Mixin for :class:`Model` classes that support feeding data for inference
via a :class:`torch:torch.utils.data.DataLoader`.
Expand Down
240 changes: 240 additions & 0 deletions fiftyone/utils/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""Segment-anything model integration.
"""
import eta.core.utils as etau
import numpy as np
import torch

import fiftyone.core.utils as fou
import fiftyone.core.models as fom
import fiftyone.utils.torch as fout
import fiftyone.zoo.models as fozm

fou.ensure_torch()

from segment_anything import SamAutomaticMaskGenerator, SamPredictor


class SegmentAnythingModelConfig(fout.TorchImageModelConfig, fozm.HasZooModel):
    """Configuration for running a :class:`SegmentAnythingModel`.

    See :class:`fiftyone.utils.torch.TorchImageModelConfig` for additional
    arguments.

    Args:
        amg_kwargs: a dictionary of keyword arguments to pass to
            ``SamAutomaticMaskGenerator``. Defaults to an empty dict
        points_mask_index: an optional index (0, 1, or 2) of the mask to use
            for each Keypoint output
    """

    def __init__(self, d):
        d = self.init(d)
        super().__init__(d)

        self.amg_kwargs = self.parse_dict(d, "amg_kwargs", default={})

        mask_index = self.parse_int(d, "points_mask_index", default=None)
        self.points_mask_index = mask_index

        # None and 0 are both falsy and need no range check; every other
        # value must be 1 or 2
        if mask_index and (mask_index < 0 or mask_index > 2):
            raise ValueError("mask_index must be 0, 1, or 2")


class SegmentAnythingModel(fout.TorchImageModel, fom.SamplesMixin):
    """Wrapper for running 'segment-anything-model' from
    https://segment-anything.com/.

    The forward pass is dispatched on ``field_type``:

    -   ``"keypoints"``: each keypoint object is used as a point prompt
    -   ``"detections"``: each detection's bounding box is used as a box
        prompt
    -   otherwise: ``SamAutomaticMaskGenerator`` is run on each image
    """

    def _download_model(self, config):
        config.download_model_if_necessary()

    def _load_network(self, config):
        entrypoint = etau.get_function(config.entrypoint_fcn)
        model = entrypoint(checkpoint=config.model_path)
        # NOTE(review): presumably disabled because the SAM helpers do their
        # own image preprocessing -- confirm
        self.preprocess = False
        return model

    @staticmethod
    def _to_numpy_input(tensor):
        """Converts a float32 torch tensor to a uint8 numpy array.

        The values are multiplied by 255 and the first axis is moved to the
        last position.

        Args:
            tensor: a float32 torch tensor

        Returns:
            a uint8 numpy array
        """
        return (tensor.cpu().numpy() * 255).astype("uint8").transpose(1, 2, 0)

    @staticmethod
    def _to_torch_output(model_output):
        """Convert SAM's automatic mask output to a single mask torch tensor.

        An all-zero background mask is prepended at index 0.

        Args:
            model_output: a non-empty list of dicts from SAM's automatic mask
                generator, each with a ``"segmentation"`` array

        Returns:
            a torch tensor of shape (num_masks + 1, height, width)
        """
        masks = [one["segmentation"].astype(int) for one in model_output]
        masks.insert(
            0, np.zeros_like(model_output[0]["segmentation"])
        )  # background
        full_mask = np.stack(masks)
        return torch.from_numpy(full_mask)

    def _forward_pass(self, inputs):
        """Runs the forward pass that matches the configured prompt type."""
        mode = self.field_type
        if mode == "keypoints":
            return self._forward_pass_points(inputs)

        if mode == "detections":
            return self._forward_pass_boxes(inputs)

        return self._forward_pass_amg(inputs)

    def _forward_pass_amg(self, inputs):
        """Runs SAM's automatic mask generator on each input image.

        Args:
            inputs: an iterable of image tensors

        Returns:
            a dict whose ``"out"`` key holds the stacked per-image mask
            tensors
        """
        mask_generator = SamAutomaticMaskGenerator(
            self._model,
            **self.config.amg_kwargs,
        )
        masks = [
            mask_generator.generate(
                self._to_numpy_input(inp),
            )
            for inp in inputs
        ]
        masks = torch.stack([self._to_torch_output(m) for m in masks])
        return dict(out=masks)

    def _forward_pass_points(self, inputs):
        """Generates one instance mask per keypoint object in
        ``self.prompts``.

        Args:
            inputs: an iterable of image tensors, paired positionally with
                ``self.prompts``

        Returns:
            a list with one dict per image containing ``"boxes"``,
            ``"labels"``, ``"scores"`` and ``"masks"`` tensors. Keypoint
            objects whose selected mask is empty are omitted
        """
        # we will change to instance segmentations
        self._output_processor = fout.InstanceSegmenterOutputProcessor(
            self.class_labels
        )

        sam_predictor = SamPredictor(self._model)
        outputs = []
        for inp, keypoints in zip(inputs, self.prompts):
            sam_predictor.set_image(self._to_numpy_input(inp))
            h, w = inp.size(1), inp.size(2)

            boxes, labels, scores, masks = [], [], [], []

            # each keypoints object will generate its own instance segmentation
            for kp in keypoints.keypoints:
                sam_points, sam_labels = generate_sam_points(kp.points, w, h)

                multi_mask, mask_scores, _ = sam_predictor.predict(
                    point_coords=sam_points,
                    point_labels=sam_labels,
                    multimask_output=True,
                )

                # Compare against None (not truthiness) so that an explicitly
                # configured index of 0 is honored rather than being replaced
                # by the best-scoring mask
                mask_index = self.config.points_mask_index
                if mask_index is None:
                    mask_index = np.argmax(mask_scores)

                mask = multi_mask[mask_index].astype(int)

                # add boxes, labels, scores, and masks
                if mask.any():
                    boxes.append(_mask_to_box(mask))
                    labels.append(self.class_labels.index(kp.label))
                    # Report the score of the mask that was actually selected
                    scores.append(min(1.0, mask_scores[mask_index]))
                    masks.append(mask)

            outputs.append(
                {
                    "boxes": torch.tensor(boxes, device=sam_predictor.device),
                    "labels": torch.tensor(
                        labels, device=sam_predictor.device
                    ),
                    "scores": torch.tensor(
                        scores, device=sam_predictor.device
                    ),
                    "masks": torch.tensor(
                        np.array(masks), device=sam_predictor.device
                    ).unsqueeze(1),
                }
            )

        return outputs

    def _forward_pass_boxes(self, inputs):
        """Generates one instance mask per detection in ``self.prompts``.

        Args:
            inputs: an iterable of image tensors, paired positionally with
                ``self.prompts``

        Returns:
            a list with one dict per image containing ``"boxes"``,
            ``"labels"``, ``"scores"`` and ``"masks"`` tensors
        """
        # we have to change it to instance segmentations
        self._output_processor = fout.InstanceSegmenterOutputProcessor(
            self.class_labels
        )

        sam_predictor = SamPredictor(self._model)
        outputs = []
        for inp, detections in zip(inputs, self.prompts):
            sam_predictor.set_image(self._to_numpy_input(inp))
            h, w = inp.size(1), inp.size(2)
            boxes = [d.bounding_box for d in detections.detections]
            sam_boxes = np.array([fo_to_sam(box, w, h) for box in boxes])
            input_boxes = torch.tensor(sam_boxes, device=sam_predictor.device)
            transformed_boxes = sam_predictor.transform.apply_boxes_torch(
                input_boxes, (h, w)
            )

            mask, _, _ = sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            outputs.append(
                {
                    "boxes": input_boxes,
                    "labels": torch.tensor(
                        [
                            self.class_labels.index(d.label)
                            for d in detections.detections
                        ],
                        device=sam_predictor.device,
                    ),
                    "scores": torch.tensor(
                        [
                            # Only a missing confidence falls back to 1.0; a
                            # genuine confidence of 0.0 is preserved
                            d.confidence if d.confidence is not None else 1.0
                            for d in detections.detections
                        ],
                        device=sam_predictor.device,
                    ),
                    "masks": mask,
                }
            )

        return outputs

    def predict_all(self, imgs, samples=None):
        """Runs prediction on the given images, optionally using prompts
        read from ``samples``.

        Args:
            imgs: the images to process
            samples: an optional object from which the prompt labels and
                class list are read via ``get_labels()`` and ``get_classes()``

        Returns:
            the value returned by ``_predict_all()``
        """
        if samples is not None:
            # The labels are wrapped in a one-element list, so only one set
            # of prompts is available to the forward pass
            # NOTE(review): presumably this relies on ragged_batches=True
            # yielding one image per batch -- confirm
            self.prompts = [self.get_labels(samples)]
            self.class_labels = self.get_classes(samples)

        return self._predict_all(imgs)


def generate_sam_points(points, w, h, negative=False):
    """Scales points by the image size and builds a matching label array.

    Args:
        points: a sequence of ``(x, y)`` points (presumably in relative
            ``[0, 1]`` coordinates -- confirm against callers)
        w: the factor applied to x coordinates (the image width)
        h: the factor applied to y coordinates (the image height)
        negative: whether to label every point 0 rather than 1

    Returns:
        a ``(scaled_points, labels)`` tuple of numpy arrays, where ``labels``
        has one float entry per point
    """
    # Written by Jacob Marks, modified by me
    frame_size = np.array([w, h])
    pixel_points = np.array(points) * frame_size
    point_labels = np.full(len(points), 0.0 if negative else 1.0)
    return pixel_points, point_labels


def fo_to_sam(box, img_width, img_height):
    """Converts an ``[x, y, width, height]`` box into a rounded integer
    ``[x1, y1, x2, y2]`` box scaled by the image size.

    Args:
        box: a 4-element ``[x, y, width, height]`` sequence
        img_width: the factor applied to ``x`` and ``width``
        img_height: the factor applied to ``y`` and ``height``

    Returns:
        an integer numpy array ``[x1, y1, x2, y2]``
    """
    # Written by Jacob Marks
    scaled = np.array(box)

    # Elements 0 and 2 are horizontal, elements 1 and 3 are vertical
    scaled[0:4:2] *= img_width
    scaled[1:4:2] *= img_height

    # Turn (width, height) into the bottom-right corner
    scaled[2:4] += scaled[0:2]

    return np.round(scaled).astype(int)


def _mask_to_box(mask):
pos_indices = np.where(mask)
minx = np.min(pos_indices[1])
maxx = np.max(pos_indices[1])
miny = np.min(pos_indices[0])
maxy = np.max(pos_indices[0])
return [minx, miny, maxx, maxy]
5 changes: 4 additions & 1 deletion fiftyone/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _predict_all(self, imgs):
if self._using_half_precision:
imgs = imgs.half()

output = self._model(imgs)
output = self._forward_pass(imgs)

if self.has_logits:
self._output_processor.store_logits = self.store_logits
Expand All @@ -413,6 +413,9 @@ def _predict_all(self, imgs):
output, frame_size, confidence_thresh=self.config.confidence_thresh
)

def _forward_pass(self, inputs):
return self._model(inputs)

def _parse_classes(self, config):
if config.labels_string is not None:
return config.labels_string.split(",")
Expand Down
Loading

0 comments on commit 39eefc3

Please sign in to comment.