Adding Segment Anything to the model zoo! #3330

Merged: 19 commits, Jul 25, 2023
14 changes: 13 additions & 1 deletion docs/scripts/make_model_zoo_docs.py
@@ -108,11 +108,23 @@

model = foz.load_zoo_model("{{ name }}")

{% if 'segment-anything' in tags %}
# Segment inside boxes
dataset.apply_model(
model,
label_field="segmentations",
prompt_field="ground_truth", # can contain Detections or Keypoints
)

# Fully automatic segmentation
dataset.apply_model(model, label_field="auto")
{% else %}
dataset.apply_model(model, label_field="predictions")
{% endif %}

session = fo.launch_app(dataset)

{% if 'zero-shot' in tags %}
{% if 'clip' in tags %}
#
# Make zero-shot predictions with custom classes
#
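For context, here is roughly what the rendered snippet would look like for one of the Segment Anything zoo models. The model name below is illustrative and should be checked against the zoo listing, but the `apply_model()` calls mirror the template above.

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")

# Illustrative checkpoint name; consult the model zoo for the exact names
model = foz.load_zoo_model("segment-anything-vitb-torch")

# Prompted segmentation: segment inside existing boxes (or keypoints)
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="ground_truth",
)

# Fully automatic segmentation
dataset.apply_model(model, label_field="auto")

session = fo.launch_app(dataset)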
143 changes: 131 additions & 12 deletions fiftyone/core/models.py
@@ -148,6 +148,7 @@ def apply_model(
"(model.has_logits = %s)" % model.has_logits
)

needs_samples = isinstance(model, SamplesMixin)
use_data_loader = (
isinstance(model, TorchModelMixin) and samples.media_type == fom.IMAGE
)
@@ -186,9 +187,21 @@ def apply_model(
# pylint: disable=no-member
context.enter_context(fou.SetAttributes(model, preprocess=False))

if needs_samples:
# pylint: disable=no-member
context.enter_context(
fou.SetAttributes(model, needs_fields=kwargs)
)

# pylint: disable=no-member
context.enter_context(model)

if needs_samples:
fields = list(model.needs_fields.values())
samples = samples.select_fields(fields)
else:
samples = samples.select_fields()

if samples.media_type == fom.VIDEO and model.media_type == "video":
return _apply_video_model(
samples,
@@ -276,13 +289,17 @@ def _apply_image_model_single(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)

with fou.ProgressBar() as pb:
for sample in pb(samples):
try:
img = etai.read(sample.filepath)
labels = model.predict(img)

if needs_samples:
labels = model.predict(img, sample=sample)
else:
labels = model.predict(img)

if filename_maker is not None:
_export_arrays(labels, sample.filepath, filename_maker)
@@ -309,14 +326,20 @@ def _apply_image_model_batch(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
samples_loader = fou.iter_batches(samples, batch_size)

with fou.ProgressBar(samples) as pb:
for sample_batch in samples_loader:
try:
imgs = [etai.read(sample.filepath) for sample in sample_batch]
labels_batch = model.predict_all(imgs)

if needs_samples:
labels_batch = model.predict_all(
imgs, samples=sample_batch
)
else:
labels_batch = model.predict_all(imgs)

for sample, labels in zip(sample_batch, labels_batch):
if filename_maker is not None:
@@ -353,7 +376,7 @@ def _apply_image_model_data_loader(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
samples_loader = fou.iter_batches(samples, batch_size)
data_loader = _make_data_loader(
samples, model, batch_size, num_workers, skip_failures
@@ -365,7 +388,12 @@
if isinstance(imgs, Exception):
raise imgs

labels_batch = model.predict_all(imgs)
if needs_samples:
labels_batch = model.predict_all(
imgs, samples=sample_batch
)
else:
labels_batch = model.predict_all(imgs)

for sample, labels in zip(sample_batch, labels_batch):
if filename_maker is not None:
@@ -400,7 +428,7 @@ def _apply_image_model_to_frames_single(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
frame_counts, total_frame_count = _get_frame_counts(samples)
is_clips = samples._dataset._is_clips

@@ -416,7 +444,11 @@ def _apply_image_model_to_frames_single(
sample.filepath, frames=frames
) as video_reader:
for img in video_reader:
labels = model.predict(img)
if needs_samples:
frame = sample.frames[video_reader.frame_number]
labels = model.predict(img, sample=frame)
else:
labels = model.predict(img)

if filename_maker is not None:
_export_arrays(
@@ -451,7 +483,7 @@ def _apply_image_model_to_frames_batch(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
frame_counts, total_frame_count = _get_frame_counts(samples)
is_clips = samples._dataset._is_clips

@@ -467,7 +499,13 @@ def _apply_image_model_to_frames_batch(
sample.filepath, frames=frames
) as video_reader:
for fns, imgs in _iter_batches(video_reader, batch_size):
labels_batch = model.predict_all(imgs)
if needs_samples:
_frames = [sample.frames[fn] for fn in fns]
labels_batch = model.predict_all(
imgs, samples=_frames
)
else:
labels_batch = model.predict_all(imgs)

if filename_maker is not None:
for labels in labels_batch:
@@ -505,7 +543,7 @@ def _apply_video_model(
skip_failures,
filename_maker,
):
samples = samples.select_fields()
needs_samples = isinstance(model, SamplesMixin)
is_clips = samples._dataset._is_clips

with fou.ProgressBar() as pb:
@@ -519,7 +557,10 @@
with etav.FFmpegVideoReader(
sample.filepath, frames=frames
) as video_reader:
labels = model.predict(video_reader)
if needs_samples:
labels = model.predict(video_reader, sample=sample)
else:
labels = model.predict(video_reader)

if filename_maker is not None:
_export_arrays(labels, sample.filepath, filename_maker)
@@ -2006,6 +2047,84 @@ def embed_prompts(self, args):
return np.stack([self.embed_prompt(arg) for arg in args])


class SamplesMixin(object):
"""Mixin for :class:`Model` classes that need samples for prediction.

Models can implement this mixin to declare that they require one or more
fields of the current sample when performing inference on its media.

The required fields are exposed via the :meth:`needs_fields` property, a dict
that maps model-specific keys to sample field names::

model.needs_fields = {"key1": "field1", "key2": "field2", ...}
"""

def __init__(self):
self._fields = {}

@property
def needs_fields(self):
"""A dict mapping model-specific keys to sample field names."""
return self._fields

@needs_fields.setter
def needs_fields(self, fields):
self._fields = fields

def predict(self, arg, sample=None):
"""Peforms prediction on the given data.

Image models should support, at minimum, processing ``arg`` values that
are uint8 numpy arrays (HWC).

Video models should support, at minimum, processing ``arg`` values that
are ``eta.core.video.VideoReader`` instances.

Args:
arg: the data
sample (None): the :class:`fiftyone.core.sample.Sample` associated
with the data

Returns:
a :class:`fiftyone.core.labels.Label` instance or dict of
:class:`fiftyone.core.labels.Label` instances containing the
predictions
"""
raise NotImplementedError("subclasses must implement predict()")

def predict_all(self, args, samples=None):
"""Performs prediction on the given iterable of data.

Image models should support, at minimum, processing ``args`` values
that are either lists of uint8 numpy arrays (HWC) or numpy array
tensors (NHWC).

Video models should support, at minimum, processing ``args`` values
that are lists of ``eta.core.video.VideoReader`` instances.

Subclasses can override this method to increase efficiency, but, by
default, this method simply iterates over the data and applies
:meth:`predict` to each.

Args:
args: an iterable of data
samples (None): an iterable of :class:`fiftyone.core.sample.Sample`
instances associated with the data

Returns:
a list of :class:`fiftyone.core.labels.Label` instances or a list
of dicts of :class:`fiftyone.core.labels.Label` instances
containing the predictions
"""
if samples is None:
return [self.predict(arg) for arg in args]

return [
self.predict(arg, sample=sample)
for arg, sample in zip(args, samples)
]


class TorchModelMixin(object):
"""Mixin for :class:`Model` classes that support feeding data for inference
via a :class:`torch:torch.utils.data.DataLoader`.
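To make the contract of the new SamplesMixin concrete, here is a minimal sketch (not part of the PR) of a model that consumes a prompt field from each sample during prediction. The class name, the "prompt_field" key, and the "ground_truth" field are hypothetical, and a real model would also implement the rest of the Model interface (media_type, ragged_batches, etc.).

import fiftyone as fo
from fiftyone.core.models import SamplesMixin


class PromptedSegmenter(SamplesMixin):
    """Hypothetical model that segments inside boxes stored on each sample."""

    def __init__(self):
        super().__init__()
        # Declare which sample field this model reads at inference time
        self.needs_fields = {"prompt_field": "ground_truth"}

    def predict(self, arg, sample=None):
        # `arg` is a uint8 numpy array (HWC); `sample` provides the prompts
        prompt_field = self.needs_fields["prompt_field"]
        prompts = sample[prompt_field] if sample is not None else None

        # ... run the underlying segmentation model on `arg` using `prompts` ...
        return fo.Detections(detections=[])

Based on the apply_model() changes above, extra keyword arguments such as prompt_field="ground_truth" appear to be forwarded to the model as its needs_fields mapping, and the fields they reference are retained in the view that is iterated during inference.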