Merge pull request #3330 from voxel51/feature/sam
Adding Segment Anything to the model zoo!
brimoor authored Jul 25, 2023
2 parents ba3efa9 + e818463 commit 1eafa36
Showing 6 changed files with 718 additions and 26 deletions.
14 changes: 13 additions & 1 deletion docs/scripts/make_model_zoo_docs.py
@@ -108,11 +108,23 @@
model = foz.load_zoo_model("{{ name }}")
{% if 'segment-anything' in tags %}
# Segment inside boxes
dataset.apply_model(
model,
label_field="segmentations",
prompt_field="ground_truth", # can contain Detections or Keypoints
)
# Full automatic segmentations
dataset.apply_model(model, label_field="auto")
{% else %}
dataset.apply_model(model, label_field="predictions")
{% endif %}
session = fo.launch_app(dataset)
{% if 'zero-shot' in tags %}
{% if 'clip' in tags %}
#
# Make zero-shot predictions with custom classes
#
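For context, here is a sketch of what the rendered zoo docs would look like for one of the new Segment Anything models once this template branch is active. The model name (`segment-anything-vitb-torch`) and the `quickstart` dataset are illustrative assumptions, not taken from this diff:

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")

# illustrative model name; see the zoo manifest for the actual SAM entries
model = foz.load_zoo_model("segment-anything-vitb-torch")

# Segment inside existing boxes (or keypoints) by pointing the model at a
# prompt field on each sample
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="ground_truth",  # can contain Detections or Keypoints
)

# Full automatic segmentations
dataset.apply_model(model, label_field="auto")

session = fo.launch_app(dataset)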
143 changes: 131 additions & 12 deletions fiftyone/core/models.py
@@ -148,6 +148,7 @@ def apply_model(
"(model.has_logits = %s)" % model.has_logits
)

needs_samples = isinstance(model, SamplesMixin)
use_data_loader = (
isinstance(model, TorchModelMixin) and samples.media_type == fom.IMAGE
)
@@ -186,9 +187,21 @@
# pylint: disable=no-member
context.enter_context(fou.SetAttributes(model, preprocess=False))

if needs_samples:
# pylint: disable=no-member
context.enter_context(
fou.SetAttributes(model, needs_fields=kwargs)
)

# pylint: disable=no-member
context.enter_context(model)

if needs_samples:
fields = list(model.needs_fields.values())
samples = samples.select_fields(fields)
else:
samples = samples.select_fields()

if samples.media_type == fom.VIDEO and model.media_type == "video":
return _apply_video_model(
samples,
@@ -276,13 +289,17 @@ def _apply_image_model_single(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)

with fou.ProgressBar() as pb:
for sample in pb(samples):
try:
img = etai.read(sample.filepath)
labels = model.predict(img)

if needs_samples:
labels = model.predict(img, sample=sample)
else:
labels = model.predict(img)

if filename_maker is not None:
_export_arrays(labels, sample.filepath, filename_maker)
@@ -309,14 +326,20 @@ def _apply_image_model_batch(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
samples_loader = fou.iter_batches(samples, batch_size)

with fou.ProgressBar(samples) as pb:
for sample_batch in samples_loader:
try:
imgs = [etai.read(sample.filepath) for sample in sample_batch]
labels_batch = model.predict_all(imgs)

if needs_samples:
labels_batch = model.predict_all(
imgs, samples=sample_batch
)
else:
labels_batch = model.predict_all(imgs)

for sample, labels in zip(sample_batch, labels_batch):
if filename_maker is not None:
@@ -353,7 +376,7 @@ def _apply_image_model_data_loader(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
samples_loader = fou.iter_batches(samples, batch_size)
data_loader = _make_data_loader(
samples, model, batch_size, num_workers, skip_failures
@@ -365,7 +388,12 @@
if isinstance(imgs, Exception):
raise imgs

labels_batch = model.predict_all(imgs)
if needs_samples:
labels_batch = model.predict_all(
imgs, samples=sample_batch
)
else:
labels_batch = model.predict_all(imgs)

for sample, labels in zip(sample_batch, labels_batch):
if filename_maker is not None:
@@ -400,7 +428,7 @@ def _apply_image_model_to_frames_single(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
frame_counts, total_frame_count = _get_frame_counts(samples)
is_clips = samples._dataset._is_clips

@@ -416,7 +444,11 @@ def _apply_image_model_to_frames_single(
sample.filepath, frames=frames
) as video_reader:
for img in video_reader:
labels = model.predict(img)
if needs_samples:
frame = sample.frames[video_reader.frame_number]
labels = model.predict(img, sample=frame)
else:
labels = model.predict(img)

if filename_maker is not None:
_export_arrays(
@@ -451,7 +483,7 @@ def _apply_image_model_to_frames_batch(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
frame_counts, total_frame_count = _get_frame_counts(samples)
is_clips = samples._dataset._is_clips

@@ -467,7 +499,13 @@
sample.filepath, frames=frames
) as video_reader:
for fns, imgs in _iter_batches(video_reader, batch_size):
labels_batch = model.predict_all(imgs)
if needs_samples:
_frames = [sample.frames[fn] for fn in fns]
labels_batch = model.predict_all(
imgs, samples=_frames
)
else:
labels_batch = model.predict_all(imgs)

if filename_maker is not None:
for labels in labels_batch:
@@ -505,7 +543,7 @@ def _apply_video_model(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
is_clips = samples._dataset._is_clips

with fou.ProgressBar() as pb:
@@ -519,7 +557,10 @@
with etav.FFmpegVideoReader(
sample.filepath, frames=frames
) as video_reader:
labels = model.predict(video_reader)
if needs_samples:
labels = model.predict(video_reader, sample=sample)
else:
labels = model.predict(video_reader)

if filename_maker is not None:
_export_arrays(labels, sample.filepath, filename_maker)
@@ -2006,6 +2047,84 @@ def embed_prompts(self, args):
return np.stack([self.embed_prompt(arg) for arg in args])


class SamplesMixin(object):
"""Mixin for :class:`Model` classes that need samples for prediction.
Models can implement this mixin to declare that they require one or more
fields of the current sample when performing inference on its media.
The fields are get/set via :meth:`needs_fields`, which is a dict that maps
model-specific keys to sample field names::
model.needs_fields = {"key1": "field1", "key2": "field2", ...}
"""

def __init__(self):
self._fields = {}

@property
def needs_fields(self):
"""A dict mapping model-specific keys to sample field names."""
return self._fields

@needs_fields.setter
def needs_fields(self, fields):
self._fields = fields

def predict(self, arg, sample=None):
"""Peforms prediction on the given data.
Image models should support, at minimum, processing ``arg`` values that
are uint8 numpy arrays (HWC).
Video models should support, at minimum, processing ``arg`` values that
are ``eta.core.video.VideoReader`` instances.
Args:
arg: the data
sample (None): the :class:`fiftyone.core.sample.Sample` associated
with the data
Returns:
a :class:`fiftyone.core.labels.Label` instance or dict of
:class:`fiftyone.core.labels.Label` instances containing the
predictions
"""
raise NotImplementedError("subclasses must implement predict()")

def predict_all(self, args, samples=None):
"""Performs prediction on the given iterable of data.
Image models should support, at minimum, processing ``args`` values
that are either lists of uint8 numpy arrays (HWC) or numpy array
tensors (NHWC).
Video models should support, at minimum, processing ``args`` values
that are lists of ``eta.core.video.VideoReader`` instances.
Subclasses can override this method to increase efficiency, but, by
default, this method simply iterates over the data and applies
:meth:`predict` to each.
Args:
args: an iterable of data
samples (None): an iterable of :class:`fiftyone.core.sample.Sample`
instances associated with the data
Returns:
a list of :class:`fiftyone.core.labels.Label` instances or a list
of dicts of :class:`fiftyone.core.labels.Label` instances
containing the predictions
"""
if samples is None:
return [self.predict(arg) for arg in args]

return [
self.predict(arg, sample=sample)
for arg, sample in zip(args, samples)
]


class TorchModelMixin(object):
"""Mixin for :class:`Model` classes that support feeding data for inference
via a :class:`torch:torch.utils.data.DataLoader`.
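As an aside, here is a minimal sketch of how a model could opt in to the new `SamplesMixin` contract. The class, field names, and return value are hypothetical and only illustrate the `needs_fields` / `predict(..., sample=...)` interface added in this diff; SAM's actual zoo integration is not part of this file:

from fiftyone.core.models import SamplesMixin

class PromptedPredictor(SamplesMixin):
    """Hypothetical predictor that reads prompts from each sample."""

    def predict(self, img, sample=None):
        # `sample` is the fiftyone Sample (or Frame) that `img` was read from
        field = self.needs_fields.get("prompt_field")
        prompts = sample[field] if sample is not None and field else None
        # ... run inference on `img`, optionally conditioned on `prompts` ...
        return None  # a real model would return fiftyone.core.labels.Label(s)

predictor = PromptedPredictor()

# apply_model() populates this dict from its kwargs, e.g.
# dataset.apply_model(model, label_field="segmentations", prompt_field="ground_truth")
predictor.needs_fields = {"prompt_field": "ground_truth"}

With this in place, `apply_model()` detects the mixin via `isinstance(model, SamplesMixin)` and passes each sample (or frame) alongside its media, as shown in the hunks above.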
