Merge pull request #571 from DagsHub/generic-annotations
Created generic prediction + annotation mechanism
kbolashev authored Jan 7, 2025
2 parents e3418d7 + c51c16e commit dea8ba0
Showing 1 changed file with 87 additions and 41 deletions.
128 changes: 87 additions & 41 deletions dagshub/data_engine/model/query_result.py
@@ -58,6 +58,13 @@

logger = logging.getLogger(__name__)

CustomPredictor = Callable[
    [
        List[str],
    ],
    List[Tuple[Any, Optional[float]]],
]
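
For illustration, a minimal function satisfying this alias could look like the following sketch (the label and score are placeholders, not part of this diff):

def constant_predictor(local_paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    # One (prediction, optional score) tuple per input file path.
    return [({"label": "cat"}, 0.5) for _ in local_paths]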


class VisualizeError(Exception):
""":meta private:"""
@@ -571,21 +578,6 @@ def predict_with_mlflow_model(
                Default batch size is 1, but inputs are still sent as a list for consistency.
            log_to_field: If set, writes prediction results to this metadata field in the datasource.
        """

        # to support dependency-free dataloading, `Batcher` is a barebones dataloader that sets up batched inference
        class Batcher:
            def __init__(self, dset, batch_size):
                self.dset = dset
                self.batch_size = batch_size

            def __iter__(self):
                self.curr_idx = 0
                return self

            def __next__(self):
                self.curr_idx += self.batch_size
                return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, self.curr_idx)]

        if not host:
            host = self.datasource.source.repoApi.host

@@ -610,32 +602,7 @@ def __next__(self):
if "torch" in loader_module:
model.predict = model.__call__

        dset = DagsHubDataset(self, tensorizers=[lambda x: x])

        predictions = {}
        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        task = progress.add_task("Running inference...", total=len(dset))
        with progress:
            for idx, local_paths in enumerate(
                Batcher(dset, batch_size) if batch_size != 1 else dset
            ):  # encapsulates dataset with batcher if necessary and iterates over it
                for prediction, remote_path in zip(
                    post_hook(model.predict(pre_hook(local_paths))),
                    [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
                ):
                    predictions[remote_path] = {
                        "data": {"image": multi_urljoin(self.datasource.source.root_raw_path, remote_path)},
                        "annotations": [prediction],
                    }
                progress.update(task, advance=batch_size, refresh=True)

        if log_to_field:
            with self.datasource.metadata_context() as ctx:
                for remote_path in predictions:
                    ctx.update_metadata(
                        remote_path, {log_to_field: json.dumps(predictions[remote_path]).encode("utf-8")}
                    )
        return predictions
        return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)
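
For context, the lambda above adapts the MLflow model to the CustomPredictor interface; a sketch of the same composition as a named helper (illustrative only, not part of this diff):

def as_custom_predictor(model, pre_hook, post_hook) -> CustomPredictor:
    # pre_hook prepares a batch of local paths for the model,
    # post_hook turns raw model output into (prediction, optional score) tuples
    return lambda paths: post_hook(model.predict(pre_hook(paths)))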

    def get_annotations(self, **kwargs) -> "QueryResult":
        """
@@ -877,6 +844,14 @@ def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
        ds.merge_samples(samples)
        return ds

    @staticmethod
    def _get_predict_dict(predictions, remote_path, log_to_field):
        res = {log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8")}
        if len(predictions[remote_path]) == 2:
            res[f"{log_to_field}_score"] = predictions[remote_path][1]

        return res
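
As a concrete illustration of the field layout this helper produces (the path, label, and score are made up):

preds = {"images/1.png": ({"label": "cat"}, 0.87)}
QueryResult._get_predict_dict(preds, "images/1.png", "prediction")
# -> {"prediction": b'{"label": "cat"}', "prediction_score": 0.87}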

    def _check_downloaded_dataset_size(self):
        download_size_prompt_threshold = 100 * (2**20)  # 100 Megabytes
        dp_size = self._calculate_datapoint_size()
@@ -939,6 +914,62 @@ def visualize(self, visualizer: Literal["dagshub", "fiftyone"] = "dagshub", **kw

        return sess

    def generate_predictions(
        self,
        predict_fn: CustomPredictor,
        batch_size: int = 1,
        log_to_field: Optional[str] = None,
    ) -> Dict[str, Tuple[str, Optional[float]]]:
        """
        Runs a generic prediction function over all the datapoints returned in this QueryResult
        and returns the results, optionally logging them to the datasource.

        Args:
            predict_fn: function that takes a batch of local file paths and returns, for each one,
                a prediction plus an optional prediction score.
            batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
            log_to_field: (optional) if set, write prediction results to this metadata field in the data engine.
                Predictions are returned either way; logging happens only if this parameter is set.
        """
        dset = DagsHubDataset(self, tensorizers=[lambda x: x])

        predictions = {}
        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        task = progress.add_task("Running inference...", total=len(dset))
        with progress:
            for idx, local_paths in enumerate(
                _Batcher(dset, batch_size) if batch_size != 1 else dset
            ):  # encapsulates dataset with batcher if necessary and iterates over it
                for prediction, remote_path in zip(
                    predict_fn(local_paths),
                    [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
                ):
                    predictions[remote_path] = prediction
                progress.update(task, advance=batch_size, refresh=True)

        if log_to_field:
            with self.datasource.metadata_context() as ctx:
                for remote_path in predictions:
                    ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
        return predictions
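
A minimal usage sketch, assuming `datasource` is an existing data engine Datasource (the predictor and field name are hypothetical):

def predict_fn(paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
    return [({"label": "cat"}, 0.9) for _ in paths]  # placeholder predictions

qr = datasource.all()
predictions = qr.generate_predictions(predict_fn, batch_size=8, log_to_field="prediction")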

    def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
        """
        Runs a generic annotation function over all the datapoints returned in this QueryResult
        and logs the results to an annotation field in the datasource.

        Args:
            predict_fn: function that takes a batch of local file paths and returns, for each one,
                a prediction plus an optional prediction score.
            batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
            log_to_field: (optional, default: 'annotation') write annotation results to this metadata field in the data engine.
        """
        self.generate_predictions(
            predict_fn,
            batch_size=batch_size,
            log_to_field=log_to_field,
        )
        self.datasource.metadata_field(log_to_field).set_annotation().apply()
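
The only difference from generate_predictions is the final line above, which marks the target field as an annotation field. A hypothetical call, reusing the predictor from the earlier sketch:

qr.generate_annotations(predict_fn, batch_size=8, log_to_field="auto_annotation")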

    def annotate(
        self,
        open_project: bool = True,
@@ -1008,3 +1039,18 @@ def log_to_mlflow(self, run: Optional["mlflow.entities.Run"] = None) -> "mlflow.
        assert self.query_data_time is not None
        artifact_name = self.datasource._get_mlflow_artifact_name("log", self.query_data_time)
        return self.datasource._log_to_mlflow(artifact_name, run, self.query_data_time)


# to support dependency-free dataloading, `_Batcher` is a barebones dataloader that sets up batched inference
class _Batcher:
    def __init__(self, dset, batch_size):
        self.dset = dset
        self.batch_size = batch_size

    def __iter__(self):
        self.curr_idx = 0
        return self

    def __next__(self):
        self.curr_idx += self.batch_size
        return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, self.curr_idx)]
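
A quick illustration of the batching behavior, with a plain list standing in for the dataset:

batches = iter(_Batcher(["a", "b", "c", "d"], batch_size=2))
next(batches)  # ["a", "b"]
next(batches)  # ["c", "d"]
# A further next() would raise IndexError from the list indexing;
# the caller is responsible for not reading past the end of the data.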
