diff --git a/dagshub/data_engine/model/query_result.py b/dagshub/data_engine/model/query_result.py
index bd60f57f..061e6c80 100644
--- a/dagshub/data_engine/model/query_result.py
+++ b/dagshub/data_engine/model/query_result.py
@@ -58,6 +58,13 @@ logger = logging.getLogger(__name__)
 
 
+CustomPredictor = Callable[
+    [
+        List[str],
+    ],
+    List[Tuple[Any, Optional[float]]],
+]
+
 
 class VisualizeError(Exception):
     """:meta private:"""
@@ -571,21 +578,6 @@ def predict_with_mlflow_model(
                 Default batch size is 1, but it is still being sent as a list for consistency.
             log_to_field: If set, writes prediction results to this metadata field in the datasource.
         """
-
-        # to support depedency-free dataloading, `Batcher` is a barebones dataloader that sets up batched inference
-        class Batcher:
-            def __init__(self, dset, batch_size):
-                self.dset = dset
-                self.batch_size = batch_size
-
-            def __iter__(self):
-                self.curr_idx = 0
-                return self
-
-            def __next__(self):
-                self.curr_idx += self.batch_size
-                return [self.dset[idx] for idx in range(self.curr_idx - self.batch_size, self.curr_idx)]
-
         if not host:
             host = self.datasource.source.repoApi.host
 
@@ -610,32 +602,7 @@ def __next__(self):
         if "torch" in loader_module:
             model.predict = model.__call__
 
-        dset = DagsHubDataset(self, tensorizers=[lambda x: x])
-
-        predictions = {}
-        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
-        task = progress.add_task("Running inference...", total=len(dset))
-        with progress:
-            for idx, local_paths in enumerate(
-                Batcher(dset, batch_size) if batch_size != 1 else dset
-            ):  # encapsulates dataset with batcher if necessary and iterates over it
-                for prediction, remote_path in zip(
-                    post_hook(model.predict(pre_hook(local_paths))),
-                    [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
-                ):
-                    predictions[remote_path] = {
-                        "data": {"image": multi_urljoin(self.datasource.source.root_raw_path, remote_path)},
-                        "annotations": [prediction],
-                    }
-                progress.update(task, advance=batch_size, refresh=True)
-
-        if log_to_field:
-            with self.datasource.metadata_context() as ctx:
-                for remote_path in predictions:
-                    ctx.update_metadata(
-                        remote_path, {log_to_field: json.dumps(predictions[remote_path]).encode("utf-8")}
-                    )
-        return predictions
+        return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)
 
     def get_annotations(self, **kwargs) -> "QueryResult":
         """
@@ -877,6 +844,14 @@ def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
         ds.merge_samples(samples)
         return ds
 
+    @staticmethod
+    def _get_predict_dict(predictions, remote_path, log_to_field):
+        # serialize the prediction payload; when a score was returned, log it to a companion "<field>_score" field
+        res = {log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8")}
+        if len(predictions[remote_path]) == 2 and predictions[remote_path][1] is not None:
+            res[f"{log_to_field}_score"] = predictions[remote_path][1]
+
+        return res
+
     def _check_downloaded_dataset_size(self):
         download_size_prompt_threshold = 100 * (2**20)  # 100 Megabytes
         dp_size = self._calculate_datapoint_size()
@@ -939,6 +914,62 @@ def visualize(self, visualizer: Literal["dagshub", "fiftyone"] = "dagshub", **kw
 
         return sess
 
+    def generate_predictions(
+        self,
+        predict_fn: CustomPredictor,
+        batch_size: int = 1,
+        log_to_field: Optional[str] = None,
+    ) -> Dict[str, Tuple[Any, Optional[float]]]:
+        """
+        Runs ``predict_fn`` over all the datapoints returned in this QueryResult
+        and returns the predictions, optionally logging them to a metadata field.
+
+        Args:
+            predict_fn: function that takes a batch of local file paths and returns, for each path,
+                a tuple of the prediction and an optional prediction score.
+            batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
+            log_to_field: (optional, default: None) if set, also writes prediction results to this
+                metadata field in the datasource; if None, predictions are only returned.
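+
+        Example (a minimal sketch; ``my_model`` is a placeholder for your own model object
+        and ``qr`` for a QueryResult, neither is part of this API)::
+
+            def predict_fn(paths: List[str]) -> List[Tuple[Any, Optional[float]]]:
+                # return one (prediction, score) tuple per local file path
+                return [(my_model.predict(path), None) for path in paths]
+
+            predictions = qr.generate_predictions(predict_fn, batch_size=8, log_to_field="prediction")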
+        """
+        dset = DagsHubDataset(self, tensorizers=[lambda x: x])
+
+        predictions = {}
+        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
+        task = progress.add_task("Running inference...", total=len(dset))
+        with progress:
+            for idx, local_paths in enumerate(
+                _Batcher(dset, batch_size) if batch_size != 1 else dset
+            ):  # wraps the dataset in a batcher when batch_size > 1 and iterates over it
+                for prediction, remote_path in zip(
+                    predict_fn(local_paths),
+                    [result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
+                ):
+                    predictions[remote_path] = prediction
+                progress.update(task, advance=batch_size, refresh=True)
+
+        if log_to_field:
+            with self.datasource.metadata_context() as ctx:
+                for remote_path in predictions:
+                    ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
+        return predictions
+
+    def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
+        """
+        Runs ``predict_fn`` over all the datapoints returned in this QueryResult,
+        logs the results to a metadata field, and marks that field as an annotation field.
+
+        Args:
+            predict_fn: function that takes a batch of local file paths and returns, for each path,
+                a tuple of the prediction and an optional prediction score.
+            batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
+            log_to_field: (optional, default: 'annotation') write prediction results to this metadata field in the datasource.
+        """
+        self.generate_predictions(
+            predict_fn,
+            batch_size=batch_size,
+            log_to_field=log_to_field,
+        )
+        self.datasource.metadata_field(log_to_field).set_annotation().apply()
+
     def annotate(
         self,
         open_project: bool = True,
@@ -1008,3 +1039,18 @@ def log_to_mlflow(self, run: Optional["mlflow.entities.Run"] = None) -> "mlflow.
         assert self.query_data_time is not None
         artifact_name = self.datasource._get_mlflow_artifact_name("log", self.query_data_time)
         return self.datasource._log_to_mlflow(artifact_name, run, self.query_data_time)
+
+
+# to support dependency-free dataloading, `_Batcher` is a barebones dataloader that sets up batched inference
+class _Batcher:
+    def __init__(self, dset, batch_size):
+        self.dset = dset
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        self.curr_idx = 0
+        return self
+
+    def __next__(self):
+        # stop at the end of the dataset instead of indexing past it, allowing a final partial batch
+        if self.curr_idx >= len(self.dset):
+            raise StopIteration
+        batch = [self.dset[idx] for idx in range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.dset)))]
+        self.curr_idx += self.batch_size
+        return batch
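
A minimal end-to-end sketch of the new API (``datasources.get_datasource`` and ``ds.all()`` are the existing data engine entry points; the repo name and ``my_model`` are placeholders):

    from dagshub.data_engine import datasources

    ds = datasources.get_datasource("<user>/<repo>", name="my-datasource")
    qr = ds.all()

    def predict_fn(paths):
        # one (prediction, score) tuple per local file path
        return [(my_model.predict(p), 0.0) for p in paths]

    qr.generate_predictions(predict_fn, batch_size=4, log_to_field="prediction")
    qr.generate_annotations(predict_fn, batch_size=4, log_to_field="annotation")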