Skip to content

Commit

Permalink
frame field rollup (frame-to-sample index)
Browse files Browse the repository at this point in the history
  • Loading branch information
swheaton committed Sep 3, 2024
1 parent 4259541 commit c4a19d3
Show file tree
Hide file tree
Showing 6 changed files with 414 additions and 8 deletions.
20 changes: 18 additions & 2 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,6 +1403,8 @@ def get_field_schema(
ftype=None,
embedded_doc_type=None,
read_only=None,
info_keys=None,
created_after=None,
include_private=False,
flat=False,
mode=None,
Expand All @@ -1418,10 +1420,16 @@ def get_field_schema(
iterable of types to which to restrict the returned schema.
Must be subclass(es) of
:class:`fiftyone.core.odm.BaseEmbeddedDocument`
include_private (False): whether to include fields that start with
``_`` in the returned schema
read_only (None): whether to restrict to (True) or exclude (False)
read-only fields. By default, all fields are included
info_keys (None): an optional key or list of keys that must be in
a field's ``info`` dict in order for it to be included in the
returned schema. If ``None``, no filtering is performed.
created_after (None): an optional ``datetime`` to filter the
returned schema by, such that the field was `created_at` after
this time. If ``None``, no filtering is performed.
include_private (False): whether to include fields that start with
``_`` in the returned schema
flat (False): whether to return a flattened schema where all
embedded document fields are included as top-level keys
mode (None): whether to apply the above constraints before and/or
Expand All @@ -1440,6 +1448,8 @@ def get_frame_field_schema(
ftype=None,
embedded_doc_type=None,
read_only=None,
info_keys=None,
created_after=None,
include_private=False,
flat=False,
mode=None,
Expand All @@ -1458,6 +1468,12 @@ def get_frame_field_schema(
:class:`fiftyone.core.odm.BaseEmbeddedDocument`
read_only (None): whether to restrict to (True) or exclude (False)
read-only fields. By default, all fields are included
info_keys (None): an optional key or list of keys that must be in
a field's ``info`` dict in order for it to be included in the
returned schema. If ``None``, no filtering is performed.
created_after (None): an optional ``datetime`` to filter the
returned schema by, such that the field was `created_at` after
this time. If ``None``, no filtering is performed.
include_private (False): whether to include fields that start with
``_`` in the returned schema
flat (False): whether to return a flattened schema where all
Expand Down
286 changes: 284 additions & 2 deletions fiftyone/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,8 @@ def get_field_schema(
ftype=None,
embedded_doc_type=None,
read_only=None,
info_keys=None,
created_after=None,
include_private=False,
flat=False,
mode=None,
Expand All @@ -1347,6 +1349,12 @@ def get_field_schema(
:class:`fiftyone.core.odm.BaseEmbeddedDocument`
read_only (None): whether to restrict to (True) or exclude (False)
read-only fields. By default, all fields are included
info_keys (None): an optional key or list of keys that must be in
a field's ``info`` dict in order for it to be included in the
returned schema. If ``None``, no filtering is performed.
created_after (None): an optional ``datetime`` to filter the
returned schema by, such that the field was `created_at` after
this time. If ``None``, no filtering is performed.
include_private (False): whether to include fields that start with
``_`` in the returned schema
flat (False): whether to return a flattened schema where all
Expand All @@ -1364,6 +1372,8 @@ def get_field_schema(
ftype=ftype,
embedded_doc_type=embedded_doc_type,
read_only=read_only,
info_keys=info_keys,
created_after=created_after,
include_private=include_private,
flat=flat,
mode=mode,
Expand All @@ -1374,6 +1384,8 @@ def get_frame_field_schema(
ftype=None,
embedded_doc_type=None,
read_only=None,
info_keys=None,
created_after=None,
include_private=False,
flat=False,
mode=None,
Expand All @@ -1393,6 +1405,12 @@ def get_frame_field_schema(
:class:`fiftyone.core.odm.BaseEmbeddedDocument`
read_only (None): whether to restrict to (True) or exclude (False)
read-only fields. By default, all fields are included
info_keys (None): an optional key or list of keys that must be in
a field's ``info`` dict in order for it to be included in the
returned schema. If ``None``, no filtering is performed.
created_after (None): an optional ``datetime`` to filter the
returned schema by, such that the field was `created_at` after
this time. If ``None``, no filtering is performed.
include_private (False): whether to include fields that start with
``_`` in the returned schema
flat (False): whether to return a flattened schema where all
Expand All @@ -1413,6 +1431,8 @@ def get_frame_field_schema(
ftype=ftype,
embedded_doc_type=embedded_doc_type,
read_only=read_only,
info_keys=info_keys,
created_after=created_after,
include_private=include_private,
flat=flat,
mode=mode,
Expand Down Expand Up @@ -1628,6 +1648,267 @@ def add_frame_field(
if expanded:
self._reload()

def add_frame_rollup_field(
    self,
    frame_path,
    rollup_path=None,
    include_counts=False,
    overwrite=True,
    create_index=True,
):
    """Populates a sample-level field on each sample in the given video
    dataset that records the unique values that appear in the frame-level
    field across all frames of that sample.

    This method is useful for generating a sample-level rollup that can be
    efficiently queried to retrieve samples that contain specific values of
    interest in at least one frame.

    The rollup field is tagged with a ``_frame_rollup`` entry in its
    ``info`` dict that stores the source frame path, and it is marked
    read-only once it has been populated.

    Examples::

        import fiftyone as fo
        import fiftyone.zoo as foz

        dataset = foz.load_zoo_dataset("quickstart-video")

        # Populate lists of observed values
        dataset.add_frame_rollup_field(
            "frames.detections.detections.label",
            "rollup_labels",
        )
        dataset.add_frame_rollup_field(
            "frames.detections.detections.confidence"
        )

        # Populate `Classifications` that record the observed values and
        # their counts
        dataset.add_frame_rollup_field(
            "frames.detections.detections.label",
            "rollup_label_counts",
            include_counts=True,
        )

        print(dataset.list_frame_rollup_fields())

    Args:
        frame_path: the frame field to rollup
        rollup_path (None): the sample-level field in which to store the
            rollup results. If ``None``, a default name of the form
            ``rollup_list_<field>`` or ``rollup_counts_<field>`` is used,
            where ``<field>`` is the frame field path with ``.`` replaced
            by ``_``
        include_counts (False): determines the format to populate
            ``rollup_path`` with:

            -   ``False``: store a list of unique observed values on each
                sample
            -   ``True``: store the results in a
                :class:`fiftyone.core.labels.Classifications` field whose
                ``label`` and ``count`` attributes encode the observed
                values and their counts, respectively. Only applicable
                when ``frame_path`` contains string values
        overwrite (True): whether to overwrite any existing rollup field
            with the same name
        create_index (True): whether to create a database index for the
            rollup field so it can be quickly filtered on

    Raises:
        ValueError: if the dataset has no frame fields, if ``frame_path``
            is not a valid frame field, or if ``rollup_path`` already
            exists and either ``overwrite`` is False or the existing field
            is not a frame rollup field
    """
    frame_rollup_key = "_frame_rollup"

    if not self._has_frame_fields():
        raise ValueError("Dataset does not contain videos")

    inpath, is_frame_field, list_fields, _, _ = self._parse_field_name(
        frame_path
    )

    if not is_frame_field or inpath not in self.get_frame_field_schema(
        flat=True
    ):
        raise ValueError("frame_path must be a valid frame field")

    # Format: rollup_[list|counts]_{embedded_frame_field}
    if rollup_path is None:
        prefix = "rollup_counts_" if include_counts else "rollup_list_"
        rollup_path = prefix + inpath.replace(".", "_")

    # Delete old field if overwriting. Only fields that were created as
    # frame rollups may be replaced; any other existing field is an error
    existing_field = self.get_field(rollup_path)
    if existing_field is not None:
        # Guard against fields whose `info` is None rather than a dict
        existing_info = existing_field.info or {}
        if not (overwrite and frame_rollup_key in existing_info):
            raise ValueError(f"Field '{rollup_path}' already exists")

        self.delete_frame_rollup_field(rollup_path)

    # Create sample field up front, recording the source frame path so the
    # field can later be recognized as a rollup
    info = {frame_rollup_key: inpath}
    if include_counts:
        self.add_sample_field(
            rollup_path,
            fo.EmbeddedDocumentField,
            embedded_doc_type=fo.Classifications,
            info=info,
        )
    else:
        self.add_sample_field(rollup_path, fo.ListField, info=info)

    # Optionally create mongodb index on rollup field
    if create_index:
        self.create_index(rollup_path)

    # Promote each frame to a root document so that the frame field can be
    # grouped by the frame's `_sample_id`
    pipeline = [
        {"$unwind": "$frames"},
        {"$replaceRoot": {"newRoot": "$frames"}},
    ]

    # NOTE(review): only the first list field along the path is unwound;
    # confirm whether paths nested inside multiple lists need more
    if list_fields:
        pipeline.append({"$unwind": "$" + list_fields[0]})

    if include_counts:
        # Values plus counts:
        # 1. Group by (_sample_id, field value) and aggregate the count
        # 2. Remove nulls
        # 3. Group by _sample_id and push all counts to a list
        # 4. Convert the list of k/v pairs into a {value: count} object
        pipeline.extend(
            [
                {
                    "$group": {
                        "_id": {
                            "sample": "$_sample_id",
                            "value": "$" + inpath,
                        },
                        "count": {"$sum": 1},
                    },
                },
                {"$match": {"$expr": {"$gt": ["$_id.value", None]}}},
                {
                    "$group": {
                        "_id": "$_id.sample",
                        "result": {
                            "$push": {"k": "$_id.value", "v": "$count"}
                        },
                    },
                },
                {"$set": {"result": {"$arrayToObject": "$result"}}},
            ]
        )
    else:
        # Values only:
        # 1. Group by _sample_id, add all values seen to a set
        # 2. Remove nulls
        pipeline.extend(
            [
                {
                    "$group": {
                        "_id": "$_sample_id",
                        "values": {"$addToSet": "$" + inpath},
                    },
                },
                {
                    "$project": {
                        "values": {
                            "$filter": {
                                "input": "$values",
                                "cond": {"$gt": ["$$this", None]},
                            }
                        }
                    }
                },
            ]
        )

    results = self._aggregate(pipeline=pipeline, attach_frames=True)

    # Parse results into a {sample ID: value} mapping
    if include_counts:
        labels = {
            str(r["_id"]): fo.Classifications(
                classifications=[
                    fo.Classification(label=k, count=v)
                    for k, v in r["result"].items()
                ]
            )
            for r in results
        }
    else:
        labels = {str(r["_id"]): list(r["values"]) for r in results}

    self.set_values(rollup_path, labels, key_field="id")

    # Now lock this field as readonly.
    rollup_field = self.get_field(rollup_path)
    _set_field_read_only(rollup_field, True)
    rollup_field.save()

def delete_frame_rollup_field(self, rollup_path):
    """Deletes the given frame rollup field from the dataset.

    The field's read-only flag is cleared and saved before the sample
    field is deleted.

    Examples::

        import fiftyone as fo
        import fiftyone.zoo as foz

        dataset = foz.load_zoo_dataset("quickstart-video")

        dataset.add_frame_rollup_field(
            "frames.detections.detections.label",
            "rollup_labels",
        )

        dataset.delete_frame_rollup_field("rollup_labels")

    Args:
        rollup_path: the sample-level path of the rollup field to delete

    Raises:
        ValueError: if ``rollup_path`` is not one of the fields returned by
            :meth:`list_frame_rollup_fields`
    """
    known_rollups = self.list_frame_rollup_fields()
    if rollup_path not in known_rollups:
        raise ValueError("Path is not a valid frame rollup field")

    # Clear the read-only flag and persist it before deleting the field
    rollup_field = self.get_field(rollup_path)
    _set_field_read_only(rollup_field, False)
    rollup_field.save()

    self.delete_sample_field(rollup_path)

def check_frame_rollup_fields(self):
    """Returns a list of frame rollup fields that could be out of sync
    with their source frame field.

    A rollup field is reported when either:

    -   the frame field recorded in its ``info["_frame_rollup"]`` entry is
        no longer in the frame schema
    -   the most recent ``frames.last_modified_at`` value is later than the
        time at which the rollup field was created

    Inclusion in this list is a heuristic, not a guarantee: a frame may
    have been modified without changing the values of the rolled-up field.

    Returns:
        a list of frame rollup field names
    """
    frames_schema = self.get_frame_field_schema(flat=True)
    rollup_schema = self.get_field_schema(info_keys=["_frame_rollup"])

    _, last_frame_mod = self.bounds("frames.last_modified_at")

    def _may_be_stale(rollup_field):
        # The source frame field no longer exists
        if rollup_field.info["_frame_rollup"] not in frames_schema:
            return True

        # No modification time is available (presumably there are no
        # frames), so there is nothing newer than the rollup
        if last_frame_mod is None:
            return False

        # Without a creation time we cannot prove the rollup is current,
        # so report it rather than risk a false negative
        created_at = rollup_field.created_at
        if created_at is None:
            return True

        # Stale when some frame was modified after the rollup was built
        return created_at < last_frame_mod

    return [
        rollup_field.name
        for rollup_field in rollup_schema.values()
        if _may_be_stale(rollup_field)
    ]

def list_frame_rollup_fields(self):
    """Lists the frame rollup fields created via
    :meth:`Dataset.add_frame_rollup_field`.

    Use :meth:`Dataset.delete_frame_rollup_field` to delete these fields,
    or :meth:`Dataset.get_field` to get information about these fields.

    Returns:
        a sorted list of the field names whose ``info`` contains the
        ``_frame_rollup`` key
    """
    rollup_schema = self.get_field_schema(info_keys="_frame_rollup")
    return sorted(rollup_schema)

def _add_implied_frame_field(
self, field_name, value, dynamic=False, validate=True
):
Expand Down Expand Up @@ -9614,8 +9895,9 @@ def _handle_nested_fields(schema):

def _set_field_read_only(field_doc, read_only):
field_doc.read_only = read_only
for _field_doc in field_doc.fields:
_set_field_read_only(_field_doc, read_only)
if hasattr(field_doc, "fields"):
for _field_doc in field_doc.fields:
_set_field_read_only(_field_doc, read_only)


def _extract_archive_if_necessary(archive_path, cleanup):
Expand Down
Loading

0 comments on commit c4a19d3

Please sign in to comment.