diff --git a/client/starwhale/api/_impl/evaluation.py b/client/starwhale/api/_impl/evaluation.py index 5fc61dc228..bb84f66c6c 100644 --- a/client/starwhale/api/_impl/evaluation.py +++ b/client/starwhale/api/_impl/evaluation.py @@ -280,14 +280,14 @@ def _starwhale_internal_run_predict(self) -> None: if not self.dataset_uris: raise FieldTypeOrValueError("context.dataset_uris is empty") join_str = "_#@#_" - cnt = 0 + + received_rows_cnt = 0 # TODO: user custom config batch size, max_retries for uri_str in self.dataset_uris: _uri = Resource(uri_str, typ=ResourceType.dataset) ds = Dataset.dataset(_uri, readonly=True) ds.make_distributed_consumption(session_id=self.context.version) dataset_info = ds.info - cnt = 0 if _uri.instance.is_local: # avoid confusion with underscores in project names idx_prefix = f"{_uri.project.name}/{_uri.name}" @@ -297,6 +297,7 @@ def _starwhale_internal_run_predict(self) -> None: raise KeyError("fetch dataset id error") idx_prefix = str(r_id) for rows in ds.batch_iter(self.predict_batch_size): + received_rows_cnt += len(rows) _start = time.time() _exception = None _results: t.Any = b"" @@ -332,8 +333,13 @@ def _starwhale_internal_run_predict(self) -> None: else: _exception = None + if len(rows) != len(_results): + console.warn( + f"The number of results({len(_results)}) is not equal to the number of rows({len(rows)}), " + "maybe batch predict does not return the expected results or ignore some predict exceptions" + ) + for (_idx, _features), _result in zip(rows, _results): - cnt += 1 _idx_with_ds = f"{idx_prefix}{join_str}{_idx}" _duration = time.time() - _start console.debug( @@ -363,7 +369,7 @@ def _starwhale_internal_run_predict(self) -> None: self.evaluation_store.flush_results() console.info( - f"{self.context.step}-{self.context.index} handled {cnt} data items for dataset {self.dataset_uris}" + f"{self.context.step}-{self.context.index} received {received_rows_cnt} data items for dataset {self.dataset_uris}" ) def _update_status(self, 
status: str) -> None: