Skip to content

Commit

Permalink
Merge pull request #11 from LCAS/workshop_6
Browse files Browse the repository at this point in the history
workshop 6 files
  • Loading branch information
gcielniak authored Nov 4, 2024
2 parents 7bf416a + 0717b03 commit 34358ba
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
13 changes: 13 additions & 0 deletions evaluation/iou_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from rectangle import Rectangle

A = Rectangle(0, 0, 20, 20)
B = Rectangle(10, 10, 20, 20)

C = A & B # intersection rectangle

print(A, A.area)
print(B, B.area)
print(C, C.area)

print(A.iou(B)) # intersection over union

106 changes: 106 additions & 0 deletions evaluation/rectangle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from itertools import product, tee

def pairwise(iterable):
"s -> (s0, s1), (s1, s2), (s2, s3), ..."
a, b = tee(iterable)
next(b, None)
return zip(a, b)

class Rectangle:

__slots__ = '__x1', '__y1', '__x2', '__y2'

def __init__(self, x1, y1, x2, y2):
self.__setstate__((min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)))

def __repr__(self):
return '{}({})'.format(type(self).__name__, ', '.join(map(repr, self)))

def __eq__(self, other):
return self.data == other.data

def __ne__(self, other):
return self.data != other.data

def __hash__(self):
return hash(self.data)

def __len__(self):
return 4

def __getitem__(self, key):
return self.data[key]

def __iter__(self):
return iter(self.data)

def __and__(self, other):
x1, y1, x2, y2 = max(self.x1, other.x1), max(self.y1, other.y1), \
min(self.x2, other.x2), min(self.y2, other.y2)
if x1 < x2 and y1 < y2:
return type(self)(x1, y1, x2, y2)

def __sub__(self, other):
intersection = self & other
if intersection is None:
yield self
else:
x, y = {self.x1, self.x2}, {self.y1, self.y2}
if self.x1 < other.x1 < self.x2:
x.add(other.x1)
if self.y1 < other.y1 < self.y2:
y.add(other.y1)
if self.x1 < other.x2 < self.x2:
x.add(other.x2)
if self.y1 < other.y2 < self.y2:
y.add(other.y2)
for (x1, x2), (y1, y2) in product(pairwise(sorted(x)),
pairwise(sorted(y))):
instance = type(self)(x1, y1, x2, y2)
if instance != intersection:
yield instance

def __getstate__(self):
return self.x1, self.y1, self.x2, self.y2

def __setstate__(self, state):
self.__x1, self.__y1, self.__x2, self.__y2 = state

@property
def x1(self):
return self.__x1

@property
def y1(self):
return self.__y1

@property
def x2(self):
return self.__x2

@property
def y2(self):
return self.__y2

@property
def width(self):
return self.x2 - self.x1

@property
def height(self):
return self.y2 - self.y1

@property
def area(self):
return self.width * self.height

def iou(self, other):
inter = self & other
return inter.area / (self.area + other.area - inter.area)

intersection = __and__

difference = __sub__

data = property(__getstate__)

0 comments on commit 34358ba

Please sign in to comment.