From e9e59540bf7ce95848572b26a5007d3face39b38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D0=BB=D0=B5=D1=81=D0=BD=D0=B8=D0=BA=D0=BE?= =?UTF-8?q?=D0=B2=20=D0=94=D0=BC=D0=B8=D1=82=D1=80=D0=B8=D0=B9=20=D0=90?= =?UTF-8?q?=D0=BD=D0=B4=D1=80=D0=B5=D0=B5=D0=B2=D0=B8=D1=87?= Date: Mon, 1 Jul 2024 13:07:44 +0300 Subject: [PATCH 1/4] new batch infer --- .../nodes/MakeCropsDetectThem.py | 88 ++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/patched_yolo_infer/nodes/MakeCropsDetectThem.py b/patched_yolo_infer/nodes/MakeCropsDetectThem.py index de96762..d0a1aa5 100644 --- a/patched_yolo_infer/nodes/MakeCropsDetectThem.py +++ b/patched_yolo_infer/nodes/MakeCropsDetectThem.py @@ -70,6 +70,7 @@ def __init__( model=None, memory_optimize=True, inference_extra_args=None, + batch_inference=False, ) -> None: if model is None: self.model = YOLO(model_path) # Load the model from the specified path @@ -91,6 +92,7 @@ def __init__( self.memory_optimize = memory_optimize # memory opimization option for segmentation self.class_names_dict = self.model.names # dict with human-readable class names self.inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters + self.batch_inference = batch_inference self.crops = self.get_crops_xy( self.image, @@ -100,7 +102,10 @@ def __init__( overlap_y=self.overlap_y, show=self.show_crops, ) - self._detect_objects() + if self.batch_inference: + self._detect_objects_batch() + else: + self._detect_objects() def get_crops_xy( self, @@ -141,6 +146,7 @@ def get_crops_xy( x_new = round((x_steps-1) * (shape_x * cross_koef_x) + shape_x) image_innitial = image_full.copy() image_full = cv2.resize(image_full, (x_new, y_new)) + batch_of_crops = [] if show: plt.figure(figsize=[x_steps*0.9, y_steps*0.9]) @@ -176,12 +182,17 @@ def get_crops_xy( x_start=x_start, y_start=y_start, )) + if self.batch_inference: + batch_of_crops.append(im_temp) if show: plt.show() print('Number of generated images:', count) - return data_all_crops + if self.batch_inference: + return data_all_crops, batch_of_crops + else: + return data_all_crops def _detect_objects(self): """ @@ -207,3 +218,76 @@ def _detect_objects(self): crop.calculate_real_values() if self.resize_initial_size: crop.resize_results() + + def _detect_objects_batch(self): + """ + Method to detect objects in batch of crop. + + This method performs batch inference using the YOLO model, + calculates real values, and optionally resizes the results. + + Returns: + None + """ + crops, batch = self.crops + self.crops = crops + self._calculate_batch_inference( + batch, + self.crops, + self.model, + imgsz=self.imgsz, + conf=self.conf, + iou=self.iou, + segment=self.segment, + classes_list=self.classes_list, + memory_optimize=self.memory_optimize, + extra_args=self.inference_extra_args + ) + for crop in self.crops: + crop.calculate_real_values() + if self.resize_initial_size: + crop.resize_results() + + def _calculate_batch_inference( + self, + batch, + crops, + model, + imgsz=640, + conf=0.35, + iou=0.7, + segment=False, + classes_list=None, + memory_optimize=False, + extra_args=None, + ): + # Perform inference + extra_args = {} if extra_args is None else extra_args + predictions = model.predict( + batch, + imgsz=imgsz, + conf=conf, + iou=iou, + classes=classes_list, + verbose=False, + **extra_args + ) + + for pred, crop in zip(predictions, crops): + + # Get the bounding boxes and convert them to a list of lists + crop.detected_xyxy = pred.boxes.xyxy.cpu().int().tolist() + + # Get the classes and convert them to a list + crop.detected_cls = pred.boxes.cls.cpu().int().tolist() + + # Get the mask confidence scores + crop.detected_conf = pred.boxes.conf.cpu().numpy() + + if segment and len(crop.detected_cls) != 0: + if memory_optimize: + # Get the polygons + crop.polygons = [mask.astype(np.uint16) for mask in pred.masks.xy] + else: + # Get the masks + crop.detected_masks = pred.masks.data.cpu().numpy() From 97c13a4c87a416d1aaeff01ad450b596da5739e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D0=BB=D0=B5=D1=81=D0=BD=D0=B8=D0=BA=D0=BE?= =?UTF-8?q?=D0=B2=20=D0=94=D0=BC=D0=B8=D1=82=D1=80=D0=B8=D0=B9=20=D0=90?= =?UTF-8?q?=D0=BD=D0=B4=D1=80=D0=B5=D0=B5=D0=B2=D0=B8=D1=87?= Date: Mon, 1 Jul 2024 14:22:55 +0300 Subject: [PATCH 2/4] fix info about batch inf --- patched_yolo_infer/nodes/MakeCropsDetectThem.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/patched_yolo_infer/nodes/MakeCropsDetectThem.py b/patched_yolo_infer/nodes/MakeCropsDetectThem.py index d0a1aa5..49ec91f 100644 --- a/patched_yolo_infer/nodes/MakeCropsDetectThem.py +++ b/patched_yolo_infer/nodes/MakeCropsDetectThem.py @@ -50,6 +50,8 @@ class MakeCropsDetectThem: image size (ps: slow operation). class_names_dict (dict): Dictionary containing class names of the YOLO model. memory_optimize (bool): Memory optimization option for segmentation (less accurate results) + batch_inference (bool): Batch inference of image crops through a neural network instead of + sequential passes of crops (ps: Faster inference, higher memory use) inference_extra_args (dict): Dictionary with extra ultralytics inference parameters """ def __init__( @@ -92,7 +94,7 @@ def __init__( self.memory_optimize = memory_optimize # memory opimization option for segmentation self.class_names_dict = self.model.names # dict with human-readable class names self.inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters - self.batch_inference = batch_inference + self.batch_inference = batch_inference # batch inference of image crops through a neural network self.crops = self.get_crops_xy( self.image, @@ -221,7 +223,7 @@ def _detect_objects(self): def _detect_objects_batch(self): """ - Method to detect objects in batch of crop. + Method to detect objects in batch of image crops. This method performs batch inference using the YOLO model, calculates real values, and optionally resizes the results. @@ -261,7 +263,7 @@ def _calculate_batch_inference( memory_optimize=False, extra_args=None, ): - # Perform inference + # Perform batch inference of image crops through a neural network extra_args = {} if extra_args is None else extra_args predictions = model.predict( batch, From 339abaf9dd43fc7781be328b9f3f7b6670382c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D0=BB=D0=B5=D1=81=D0=BD=D0=B8=D0=BA=D0=BE?= =?UTF-8?q?=D0=B2=20=D0=94=D0=BC=D0=B8=D1=82=D1=80=D0=B8=D0=B9=20=D0=90?= =?UTF-8?q?=D0=BD=D0=B4=D1=80=D0=B5=D0=B5=D0=B2=D0=B8=D1=87?= Date: Mon, 1 Jul 2024 15:09:35 +0300 Subject: [PATCH 3/4] batch processing v1.2.7 --- README.md | 1 + patched_yolo_infer/README.md | 3 ++- requirements.txt | 2 +- setup.py | 10 ++++++---- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 8dc56bf..ceb2f1c 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,7 @@ Class implementing cropping and passing crops through a neural network for detec | resize_initial_size | bool | False | Whether to resize the results to the original input image size (ps: slow operation). | | memory_optimize | bool | True | Memory optimization option for segmentation (less accurate results when enabled). | | inference_extra_args | dict | None | Dictionary with extra ultralytics [inference parameters](https://docs.ultralytics.com/modes/predict/#inference-arguments) (possible keys: half, device, max_det, augment, agnostic_nms and retina_masks) | +| batch_inference | bool | False | Batch inference of image crops through a neural network instead of sequential passes of crops (ps: faster inference, higher gpu memory use). | **CombineDetections** diff --git a/patched_yolo_infer/README.md b/patched_yolo_infer/README.md index 34bfb59..affa778 100644 --- a/patched_yolo_infer/README.md +++ b/patched_yolo_infer/README.md @@ -9,7 +9,7 @@ This library facilitates various visualizations of inference results from ultral You can install the library via pip: ```bash -pip install patched_yolo_infer +pip install patched-yolo-infer ``` Note: If CUDA support is available, it's recommended to pre-install PyTorch with CUDA support before installing the library. Otherwise, the CPU version will be installed by default. @@ -99,6 +99,7 @@ Class implementing cropping and passing crops through a neural network for detec - **resize_initial_size** (*bool*): Whether to resize the results to the original image size (ps: slow operation). - **memory_optimize** (*bool*): Memory optimization option for segmentation (less accurate results when enabled). - **inference_extra_args** (*dict*): Dictionary with extra ultralytics [inference parameters](https://docs.ultralytics.com/modes/predict/#inference-arguments) (possible keys: half, device, max_det, augment, agnostic_nms and retina_masks) +- **batch_inference** (*bool*): Batch inference of image crops through a neural network instead of sequential passes of crops (ps: faster inference, higher gpu memory use) **CombineDetections** Class implementing combining masks/boxes from multiple crops + NMS (Non-Maximum Suppression).\ diff --git a/requirements.txt b/requirements.txt index cc46d44..15db54c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ +numpy<2.0 torch -numpy opencv-python matplotlib ultralytics \ No newline at end of file diff --git a/setup.py b/setup.py index 9074b47..06d7c5a 100644 --- a/setup.py +++ b/setup.py @@ -8,8 +8,8 @@ long_description = "\n" + fh.read() -VERSION = '1.2.6' -DESCRIPTION = '''YOLO-Patch-Based-Inference for detection/segmentation of small objects in images.''' +VERSION = '1.2.7' +DESCRIPTION = '''Patch-Based-Inference for detection/segmentation of small objects in images.''' setup( name="patched_yolo_infer", @@ -23,7 +23,7 @@ packages=find_packages(), python_requires=">=3.8", install_requires=[ - 'numpy', + 'numpy<2.0', 'opencv-python', 'matplotlib', 'torch', @@ -33,8 +33,10 @@ "python", "yolov8", "yolov9", + "yolov10", "rtdetr", - "sam", + "fastsam", + "sahi", "object detection", "instance segmentation", "patch-based inference", From 422062958f97ab11c795ee445377a01c184ab00e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D0=BB=D0=B5=D1=81=D0=BD=D0=B8=D0=BA=D0=BE?= =?UTF-8?q?=D0=B2=20=D0=94=D0=BC=D0=B8=D1=82=D1=80=D0=B8=D0=B9=20=D0=90?= =?UTF-8?q?=D0=BD=D0=B4=D1=80=D0=B5=D0=B5=D0=B2=D0=B8=D1=87?= Date: Mon, 1 Jul 2024 21:23:24 +0300 Subject: [PATCH 4/4] TensorRT support --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9d995d2..c937cfc 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ setup.cfg build info_how_pip_upload.txt examples/patched_yolo_infer +**.engine **.ipynb \ No newline at end of file