Merge pull request #13 from remydubois/fix/issue_12
- Fixed issue #12, by masking scores as well as boxes.
- Added torch and torchvision as proper dev dependencies
- Pinned Pillow to 9.3.0 in dev dependencies because 9.4.0 does not compile on my MacBook Pro (see python-pillow/Pillow#6862)
- Removed deprecated arguments: `cutoff_distance` and `tree`. Removed associated tests.
- Added sanity check to ensure `leaf_size` is strictly positive.
remydubois authored Jan 21, 2023
2 parents 62139e6 + f9c1429 commit 24a3812
Showing 18 changed files with 581 additions and 372 deletions.
23 changes: 15 additions & 8 deletions .flake8
@@ -1,10 +1,17 @@
[flake8]
ignore =
E20, # Extra space in brackets
E231,E241, # Multiple spaces around ","
E26, # Comments
E731, # Assigning lambda expression
E741, # Ambiguous variable names
W503, # line break before binary operator
W504, # line break after binary operator
max-line-length = 100
# Extra space in brackets
E20,
# Multiple spaces around ","
E231,E241,
# Comments
E26,
# Assigning lambda expression
E731,
# Ambiguous variable names
E741,
# line break before binary operator
W503,
# line break after binary operator
W504,
max-line-length = 100
40 changes: 40 additions & 0 deletions .github/workflows/python-package.yml
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.7", "3.8", "3.9"]

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade poetry
        python -m pip install flake8
        python -m poetry install --with test
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        python -m poetry run pytest
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,20 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/psf/black
    rev: "22.12.0"
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
    rev: v5.11.3
    hooks:
      - id: isort
        name: isort (python)
  - repo: https://github.com/pycqa/flake8
    rev: "5.0.4"
    hooks:
      - id: flake8
20 changes: 10 additions & 10 deletions README.md
@@ -1,5 +1,5 @@
# LSNMS
Speeding up Non Maximum Suppression (with multiclass support) on very large images by several folds, using a sparse implementation of NMS.
This project is useful for very high-dimensional image data, when the number of predicted instances to prune becomes considerable (> 10,000 objects).

<p float="center">
@@ -52,31 +52,31 @@ pooled_boxes, pooled_scores, cluster_indices = wbc(boxes, scores, iou_threshold=
```
# Description
## Non Maximum Suppression
<!-- Non maximum suppression is an essential step of object detection tasks, aiming at pruning away multiple predictions actually predicting the same instance. This algorithm works greedily by sorting predicted boxes in decreasing order of score and, step by step, pruning away boxes having a high intersection over union with any other box with a higher confidence score (it deletes all the non-maximally-scored overlapping boxes). The picture below depicts the overall process: in the left image, several bounding boxes actually predict the same instance (the model's face). In the right image, NMS was applied to prune away redundant boxes and keep only the highest scoring box.
Note: confidence scores are not represented on this image.
<p float="center">
<center><img src="./assets/images/nms_fast_03.jpeg" width="700" />
<figcaption>NMS example (source https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/)</figcaption></center>
</p> -->
A nice introduction to the non maximum suppression algorithm can be found here: https://www.coursera.org/lecture/convolutional-neural-networks/non-max-suppression-dvrjH.
Basically, NMS discards redundant boxes in a set of predicted instances. It is an essential - and often unavoidable - step of object detection pipelines.
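For readers who prefer code, here is a minimal sketch of that greedy procedure (a naive `O(n**2)` reference written for illustration only - it is not this library's implementation, and the `iou` helper exists just for this example):

```python
import numpy as np

def iou(a, b):
    # Intersection over union of two (x0, y0, x1, y1) boxes.
    x0, y0 = max(a[0], b[0]), max(a[1], b[1])
    x1, y1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x1 - x0) * max(0.0, y1 - y0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def naive_nms(boxes, scores, iou_threshold=0.5):
    # Visit boxes by decreasing confidence; keep a box, then suppress every
    # remaining box overlapping it beyond the IoU threshold.
    order = np.argsort(scores)[::-1]
    suppressed = np.zeros(len(boxes), dtype=bool)
    keep = []
    for i in order:
        if suppressed[i]:
            continue
        keep.append(i)
        for j in order:
            if j == i or suppressed[j]:
                continue
            if iou(boxes[i], boxes[j]) > iou_threshold:
                suppressed[j] = True
    return np.array(keep)
```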


## Scaling up the Non Maximum Suppression process
### Complexity
* In the best case scenario, NMS is a **linear-complex** process (`O(n)`): if all boxes are perfectly overlapping, then one pass of the algorithm discards all the boxes except the highest scoring one.
* In the worst case scenario, NMS is a **quadratic-complex** operation (one needs to perform `n * (n - 1) / 2` IoU comparisons): if all boxes are perfectly disconnected, each NMS step discards only one box (the highest scoring one, by decreasing order of score). Hence, one needs to perform `(n-1) + (n-2) + ... + 1 = n * (n - 1) / 2` IoU computations; the short count below gives a sense of scale.
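A back-of-the-envelope count (illustrative arithmetic only, not a benchmark; the 300,000 figure echoes the concrete example further down):

```python
import math

n = 300_000                       # e.g. boxes left after patch-wise NMS on a huge image
quadratic = n * (n - 1) // 2      # worst-case pairwise IoU comparisons: 44,999,850,000
n_log_n = int(n * math.log2(n))   # order of an n*log(n) process: about 5.5 million
```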
### Working with huge images
When working with high-dimensional images (such as satellite or histology images), one often runs object detection inference by patching (with overlap) the input image and applying NMS to independent patches. Because patches do overlap, a final NMS needs to be re-applied afterward.
In that final case, one is close to the worst case scenario since each NMS step will discard only a very small number of candidate instances (roughly the number of overlapping patches covering each instance, usually <= 10). Hence, depending on the size of the input image, computation time can reach several minutes on CPU.
A more natural way to speed up NMS could be through parallelization, like it is done for GPU-based implementations, but:
1. Efficiently parallelizing NMS is not a straightforward process
2. If too many instances are predicted, GPU VRAM will often not be sufficient, preventing one from using GPU accelerators
3. The process remains quadratic, and does not scale well.
### LSNMS
This project offers a way to overcome the aforementioned issues elegantly:
1. Before the NMS process, a R-Tree is built on bounding boxes (in a `O(n*log(n))` time)
2. At each NMS step, only boxes overlapping with the current highest scoring box are queried in the tree (in `O(log(n))` time), and only those neighbors are considered in the pruning process: IoU computation + pruning if necessary. Hence, the overall NMS process is turned from an `O(n**2)` into an `O(n * log(n))` process. See a comparison of run times on the graph below (results obtained on sets of instances whose coordinates vary between 0 and 10,000 in x and y).
A nice introduction to R-Trees can be found here: https://iq.opengenus.org/r-tree/.
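A condensed sketch of this loop is shown below. It paraphrases the `_nms` routine from `lsnms/nms.py` (visible in the diff further down) but is simplified - no score threshold, no multiclass offsets - so treat it as illustrative rather than as the library's exact code:

```python
import numpy as np
from lsnms.rtree import RNode
from lsnms.util import max_spread_axis

def sketch_nms(boxes, scores, iou_threshold=0.5, leaf_size=32):
    # boxes: float64 array of shape (n, 4) in (x0, y0, x1, y1) format.
    # Build the R-Tree once: O(n*log(n)).
    rtree = RNode(boxes, leaf_size, max_spread_axis(boxes), None)
    rtree.build()

    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]            # decreasing confidence
    alive = np.ones(len(boxes), dtype=np.bool_)
    keep = []

    for i in order:
        if not alive[i]:
            continue
        keep.append(i)
        # Only boxes overlapping boxes[i] are fetched from the tree: O(log(n)).
        neighbors, inters = rtree.intersect(boxes[i], 1.0)
        for j, inter in zip(neighbors, inters):
            if j == i or not alive[j]:
                continue
            iou = inter / (areas[i] + areas[j] - inter)
            if iou > iou_threshold:
                alive[j] = False                # prune redundant neighbor
    return np.array(keep)
```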

Note that the timings reported below are all-inclusive: they notably include the tree building process, otherwise the comparison would not be fair.
@@ -89,13 +89,13 @@ Note that the timing reported below are all inclusive: it notably includes the t

For the sake of speed, this repo is entirely (including the binary tree) built using Numba's just-in-time compilation.

>Concrete example:
>Some tests were run on ~ 40k x 40k pixel images, with detection inference run on 512 x 512 overlapping patches (256-strided). Approximately 300,000 bounding boxes (post patch-wise NMS) resulted. Naive NMS ran in approximately 5 minutes on a modern CPU, while this implementation ran in 5 seconds, hence offering close to a 60-fold speed up.
### Going further: weighted box clustering
For the sake of completeness, this repo also implements a variant of the Weighted Box Clustering algorithm (from https://arxiv.org/pdf/1811.08661.pdf). Since NMS can artificially push up confidence scores (by selecting only the highest scoring box per instance), WBC overcomes this by averaging box coordinates and scores of all the overlapping boxes (instead of discarding all the non-maximally scored overlapping boxes).
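As a minimal illustration of that pooling step (the `average_cluster` helper below is not part of the library; see `lsnms/wbc.py` for the actual implementation and its exact weighting scheme):

```python
import numpy as np

def average_cluster(boxes, scores):
    # Pool one cluster of mutually overlapping boxes into a single box:
    # coordinates are averaged (here weighted by confidence) and scores are
    # averaged, instead of keeping only the highest scoring box as NMS would.
    weights = scores / scores.sum()
    pooled_box = (boxes * weights[:, None]).sum(axis=0)
    pooled_score = scores.mean()
    return pooled_box, pooled_score
```

In the library, `wbc(boxes, scores, iou_threshold=...)` returns `pooled_boxes, pooled_scores, cluster_indices`, as in the usage snippet referenced earlier in this README.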

## Disclaimer:
1. The tree implementation could probably be further optimized, see implementation notes below.
2. A much simpler implementation could rely on an existing KD-Tree implementation (such as sklearn's), query the tree before NMS, and tweak the NMS process to accept the tree query's result. This repo implements it from scratch in pure Numba for the sake of completeness and elegance.
3. The main parameter deciding the speed up brought by this method is (along with the amount of instances) the **density** of boxes over the image: in other words, the amount of overlapping boxes trimmed at each step of the NMS process. The lower the density of boxes, the higher the speed up factor.
@@ -109,7 +109,7 @@ As said above, the main parameter guiding speed up from naive NMS is instance (o
---
# Implementations notes
## Tree implementation
Due to Numba compiler's limitations, the tree implementation has some specificities:
Because jitclass methods can not be recursive, the tree building process (node splitting + children instantiation) can not be entirely done inside the `Node.__init__` method:
* Otherwise, the `__init__` method would be recursive (children instantiation)
* However, methods can call recursive (instance-external) functions: a `build` function is dedicated to this
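The pattern looks roughly like the toy example below (this is not the repo's `RNode`/`build` code - the class, its fields and the counting logic are made up purely to show a jitclass method delegating recursion to an external njit function):

```python
import numpy as np
from numba import njit, int64, float64
from numba.experimental import jitclass

@njit
def build(data, leaf_size):
    # Recursive, instance-external helper: jitclass methods cannot recurse,
    # but they may call a recursive module-level njit function such as this one.
    if len(data) <= leaf_size:
        return 1  # a leaf
    mid = len(data) // 2
    return 1 + build(data[:mid], leaf_size) + build(data[mid:], leaf_size)

@jitclass([("data", float64[:, :]), ("leaf_size", int64), ("n_nodes", int64)])
class ToyNode:
    def __init__(self, data, leaf_size):
        self.data = data
        self.leaf_size = leaf_size
        self.n_nodes = 0

    def build_tree(self):
        # The recursive part happens outside the class.
        self.n_nodes = build(self.data, self.leaf_size)

node = ToyNode(np.random.uniform(0.0, 100.0, (1000, 4)), 16)
node.build_tree()
```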
11 changes: 9 additions & 2 deletions changelog.md
@@ -1,6 +1,15 @@
Changelog
=========

Version 0.3.2
------------
- Fixed issue https://github.com/remydubois/lsnms/issues/12, by masking scores as well as boxes.
- Added torch and torchvision as proper dev dependencies
- Pinned Pillow to 9.3.0 in dev dependencies because 9.4.0 does not compile on my MacBook Pro (see https://github.com/python-pillow/Pillow/issues/6862)
- Removed deprecated arguments: `cutoff_distance` and `tree`. Removed associated tests.
- Added sanity check to ensure `leaf_size` is strictly positive.
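In essence (paraphrasing the fix that appears in the `lsnms/nms.py` diff below, nothing beyond it): when boxes under `score_threshold` are dropped before building the tree, the scores must be filtered with the same mask, otherwise boxes and scores fall out of alignment. A minimal sketch of the idea, with an illustrative helper name:

```python
import numpy as np

def filter_by_score(boxes, scores, score_threshold):
    # The gist of the issue #12 fix: apply one mask to both arrays so that
    # boxes[i] and scores[i] keep referring to the same detection.
    score_mask = scores > score_threshold
    return boxes[score_mask], scores[score_mask]
```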


Version 0.3.1
------------
- Edge case where all box scores are zero (or all below threshold) is now handled (it threw an ugly error before)
@@ -33,5 +42,3 @@ Version 0.1.0
- Both BallTree and KDTree are implemented for the sake of exhaustivity
- A cutoff distance needs to be specified to discard boxes too distant from one another
- Compilation time at first use is quite long: 13 seconds, and functions can not be precompiled due to recursion.


4 changes: 2 additions & 2 deletions lsnms/__init__.py
@@ -1,2 +1,2 @@
from lsnms.nms import nms
from lsnms.wbc import wbc
from lsnms.nms import nms # noqa: F401
from lsnms.wbc import wbc # noqa: F401
55 changes: 26 additions & 29 deletions lsnms/nms.py
@@ -1,9 +1,16 @@
from typing import Optional
import warnings
from numba import njit

import numpy as np
from lsnms.rtree import RTree, RNode
from lsnms.util import area, intersection, check_correct_input, offset_bboxes, max_spread_axis
from numba import njit

from lsnms.rtree import RNode
from lsnms.util import (
area,
check_correct_input,
intersection,
max_spread_axis,
offset_bboxes,
)


@njit(cache=False)
@@ -12,21 +19,24 @@ def _nms(
scores: np.array,
iou_threshold: float = 0.5,
score_threshold: float = 0.0,
tree_leaf_size: int = 32,
rtree_leaf_size: int = 32,
) -> np.array:
"""
See `lsnms.nms` docstring.
"""
keep = []

# Discard boxes below score threshold right now to avoid building the tree on useless boxes
boxes = boxes[scores > score_threshold]
# Discard boxes and scores below score threshold right now to avoid building the tree on
# useless boxes
score_mask = scores > score_threshold
boxes = boxes[score_mask]
scores = scores[score_mask]

if len(boxes) == 0:
return np.zeros(0, dtype=np.int64)

# Build the BallTree
rtree = RNode(boxes, tree_leaf_size, max_spread_axis(boxes), None)
# Build the RTree
rtree = RNode(boxes, rtree_leaf_size, max_spread_axis(boxes), None)
rtree.build()

# Compute the areas once and for all: avoid recomputing it at each step
@@ -48,7 +58,8 @@ def nms(
boxA = boxes[current_idx]

# Query the overlapping boxes and return their intersection
query, query_intersections = rtree.intersect(boxA, 0.0)
# return only boxes which have at least one pixel of overlap with the box of interest
query, query_intersections = rtree.intersect(boxA, 1.0)

for query_idx, overlap in zip(query, query_intersections):
if not to_consider[query_idx]:
@@ -69,9 +80,7 @@ def nms(
iou_threshold: float = 0.5,
score_threshold: float = 0.0,
class_ids: Optional[np.array] = None,
cutoff_distance: Optional[int] = None,
tree: Optional[str] = None,
tree_leaf_size: int = 32,
rtree_leaf_size: int = 32,
) -> np.array:
"""
Sparse NMS, will perform Non Maximum Suppression by only comparing overlapping boxes.
@@ -92,10 +101,10 @@ def nms(
boxes = # array of boxes in format pascal VOC (x0, y0, x1, y1)
scores = # one-dimensional array of confidence scores
class_ids = # one-dimensional array of class indicators (one per object)
keep = nms(boxes, scores, iou_threshold=0.5, score_threshold=0., class_ids=class_ids)
```
Note that this implementation could be further optimized:
- Memory management is quite poor: several back and forth list-to-numpy conversions happen
@@ -116,26 +125,14 @@ def nms(
One-dimensional integer array indicating the respective classes of the bboxes. If this
is not None, a class-wise NMS will be applied. If None, all boxes are considered of the
same class.
cutoff_distance: int, optional
DEPRECATED, used for compatibility with version 0.1.X.
Since version 0.2.X, it is useless because overlapping boxes are queried using a R-Tree,
which is parameter free.
tree: str, optional
DEPRECATED, used for compatibility with version 0.1.X.
Since version 0.2.X, the tree used is a R-Tree.
tree_leaf_size: int, optional
rtree_leaf_size: int, optional
The leaf size parameter of the underlying R-Tree built for box query.
Returns
-------
np.array
Indices of boxes kept, in decreasing order of confidence score.
"""
if cutoff_distance is not None or tree is not None:
warnings.warn(
"Both `cutoff_distance` and `tree` are deprecated and effect-less from version"
"0.2.X, since R-Tree is used by default to query overlapping boxes."
)

if class_ids is None:
class_ids = np.zeros(len(boxes), dtype=np.int64)
@@ -154,7 +151,7 @@ def nms(
scores,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
tree_leaf_size=tree_leaf_size,
rtree_leaf_size=rtree_leaf_size,
)

return keep
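For reference, a call against the updated signature might look like this (the box generation is purely illustrative; `rtree_leaf_size=32` simply restates the new default shown above):

```python
import numpy as np
from lsnms import nms

# Fabricate some (x0, y0, x1, y1) boxes and confidence scores.
topleft = np.random.uniform(0.0, 1_000.0, (10_000, 2))
wh = np.random.uniform(10.0, 50.0, (10_000, 2))
boxes = np.concatenate([topleft, topleft + wh], axis=1)
scores = np.random.uniform(0.0, 1.0, 10_000)

keep = nms(boxes, scores, iou_threshold=0.5, score_threshold=0.1, rtree_leaf_size=32)
```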
19 changes: 11 additions & 8 deletions lsnms/rtree.py
@@ -1,17 +1,17 @@
import numpy as np
from numba import njit
from collections import OrderedDict
from typing import List

import numpy as np
from numba import boolean, deferred_type, float64, int64, njit, optional
from numba.experimental import jitclass
from numba import deferred_type, optional, int64, float64, boolean

from lsnms.util import (
intersection,
split_along_axis,
box_englobing_boxes,
intersection,
max_spread_axis,
split_along_axis,
)


specs = OrderedDict()
node_type = deferred_type()
specs["data"] = float64[:, :]
@@ -57,6 +57,7 @@ def __init__(self, data, leaf_size=16, axis=0, indices=None):
self.data = data
self.axis = axis
# Quick sanity checks
assert leaf_size > 0, "Leaf size must be strictly positive"
assert len(data) > 0, "Empty dataset"
assert self.data.shape[-1] % 2 == 0, "odd dimensionality"
assert data.ndim == 2, "Boxes to index should be (n_boxes, 4)"
@@ -215,10 +216,12 @@ def intersect(
X : np.array
Query box (one box).
indices_buffer : list
List of currently-gathered neighbors. Stores in-place the neighbor indices along the search process
List of currently-gathered neighbors. Stores in-place the neighbor indices along the
search process
intersection_buffer : list
List of currently-gathered neighbor intersection with the query box.
Since the redundancy criterion is intersection over union, I store it here to avoid recomputing it later.
Since the redundancy criterion is intersection over union, I store it here to avoid
recomputing it later.
inter_UB : float, optional
Intersection upper bound: this is the intersection of X with the current node's bbox. By
definition, this is the highest intersection a box contained in this node can get with X.
8 changes: 4 additions & 4 deletions lsnms/util.py
@@ -1,8 +1,8 @@
from numba import njit, int64, float64
from typing import Optional
from numba.typed import Dict
import math
from typing import Optional

import numpy as np
from numba import njit


@njit(cache=True)
@@ -77,7 +77,7 @@ def distance_to_hypersphere(X, centroid, radius):
Distance to the sphere.
"""
centroid_dist = rdist(X, centroid)
return max(0, centroid_dist ** 0.5 - radius ** 0.5) ** 2
return max(0, centroid_dist**0.5 - radius**0.5) ** 2


@njit(cache=True)
8 changes: 5 additions & 3 deletions lsnms/wbc.py
@@ -1,9 +1,11 @@
from numba import njit
import warnings
from typing import Optional

import numpy as np
from lsnms.util import area, check_correct_input
from numba import njit

from lsnms.rtree import RTree
import warnings
from lsnms.util import area, check_correct_input


@njit