diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py
index 5ec2849d..8e658144 100644
--- a/benchmark/benchmark.py
+++ b/benchmark/benchmark.py
@@ -1,10 +1,11 @@
 import os
 import argparse
-
+# import sys
+# sys.path.append('/')
 import yaml
 import numpy as np
 import cv2 as cv
-
+from detection_metric import detect_metric
 from models import MODELS
 from utils import METRICS, DATALOADERS
 
@@ -118,13 +119,17 @@ def printResults(self):
     # prepend PYTHONPATH to each path
     prepend_pythonpath(cfg)
 
-    # Instantiate benchmarking
-    benchmark = Benchmark(**cfg['Benchmark'])
-
     # Instantiate model
     model = build_from_cfg(cfg=cfg['Model'], registery=MODELS, key='name')
-
-    # Run benchmarking
     print('Benchmarking {}:'.format(model.name))
-    benchmark.run(model)
-    benchmark.printResults()
+
+    type = cfg['Benchmark']['type']
+    if type == "Detection metrics":
+        detect_metric(model, **cfg['Benchmark'])
+    else:
+        # Instantiate benchmarking
+        benchmark = Benchmark(**cfg['Benchmark'])
+
+        # Run benchmarking
+        benchmark.run(model)
+        benchmark.printResults()
diff --git a/benchmark/cal_rescall/rrc_evaluation_funcs.py b/benchmark/cal_rescall/rrc_evaluation_funcs.py
new file mode 100644
index 00000000..656a2ff9
--- /dev/null
+++ b/benchmark/cal_rescall/rrc_evaluation_funcs.py
@@ -0,0 +1,416 @@
+# encoding: UTF-8
+# https://rrc.cvc.uab.es/?ch=2&com=mymethods&task=1
+
+import json
+import sys
+
+sys.path.append('/')
+import zipfile
+import re
+import sys
+import os
+import codecs
+import traceback
+
+
+def print_help():
+    sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' % sys.argv[0])
+    sys.exit(2)
+
+
+def load_zip_file_keys(file, fileNameRegExp=''):
+    """
+    Returns an array with the entries of the ZIP file that match the regular expression.
+    The keys are the file names or the capturing group defined in the fileNameRegExp
+    """
+    try:
+        archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+    except:
+        raise Exception('Error loading the ZIP archive.')
+
+    pairs = []
+
+    for name in archive.namelist():
+        addFile = True
+        keyName = name
+        if fileNameRegExp != "":
+            m = re.match(fileNameRegExp, name)
+            if m == None:
+                addFile = False
+            else:
+                if len(m.groups()) > 0:
+                    keyName = m.group(1)
+
+        if addFile:
+            pairs.append(keyName)
+
+    return pairs
+
+
+def load_zip_file(file, fileNameRegExp='', allEntries=False):
+    """
+    Returns a dict with the contents (filtered by fileNameRegExp) of a ZIP file.
+    The keys are the file names or the capturing group defined in the fileNameRegExp
+    allEntries validates that all entries in the ZIP file pass the fileNameRegExp
+    """
+    try:
+        archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
+    except:
+
+        raise Exception('Error loading the ZIP archive')
+
+    pairs = []
+    for name in archive.namelist():
+        addFile = True
+        keyName = name
+        if fileNameRegExp != "":
+            m = re.match(fileNameRegExp, name)
+            if m == None:
+                addFile = False
+            else:
+                if len(m.groups()) > 0:
+                    keyName = m.group(1)
+
+        if addFile:
+            pairs.append([keyName, archive.read(name)])
+        else:
+            if allEntries:
+                raise Exception('ZIP entry not valid: %s' % name)
+
+    return dict(pairs)
+
+
+def load_folder_file(file, fileNameRegExp='', allEntries=False):
+    """
+    Returns a dict with the contents (filtered by fileNameRegExp) of a folder.
+    The keys are the file names or the capturing group defined in the fileNameRegExp
+    allEntries validates that all entries in the folder pass the fileNameRegExp
+    """
+    pairs = []
+    for name in os.listdir(file):
+        addFile = True
+        keyName = name
+        if fileNameRegExp != "":
+            m = re.match(fileNameRegExp, name)
+            if m == None:
+                addFile = False
+            else:
+                if len(m.groups()) > 0:
+                    keyName = m.group(1)
+
+        if addFile:
+            pairs.append([keyName, open(os.path.join(file, name)).read()])
+        else:
+            if allEntries:
+                raise Exception('ZIP entry not valid: %s' % name)
+
+    return dict(pairs)
+
+
+def decode_utf8(raw):
+    """
+    Returns a Unicode object on success, or None on failure
+    """
+    try:
+        raw = codecs.decode(raw, 'utf-8', 'replace')
+        # extracts BOM if exists
+        raw = raw.encode('utf8')
+        if raw.startswith(codecs.BOM_UTF8):
+            raw = raw.replace(codecs.BOM_UTF8, b'', 1)
+        return raw.decode('utf-8')
+    except:
+        return None
+
+
+def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False,
+                           imWidth=0, imHeight=0):
+    """
+    This function validates all lines of the file by calling the line validation function for each line
+    """
+    utf8File = decode_utf8(file_contents)
+    if (utf8File is None):
+        raise Exception("The file %s is not UTF-8" % fileName)
+
+    lines = utf8File.split("\r\n" if CRLF else "\n")
+    for line in lines:
+        line = line.replace("\r", "").replace("\n", "")
+        if (line != ""):
+            try:
+                validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+            except Exception as e:
+
+                raise Exception(
+                    ("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode(
+                        'utf-8', 'replace'))
+
+
+def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0):
+    """
+    Validate the format of the line. If the line is not valid an exception will be raised.
+    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+    Possible values are:
+    LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+    LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+    """
+    get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
+
+
+def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
+    """
+    Validate the format of the line. If the line is not valid an exception will be raised.
+    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
+    Possible values are:
+    LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
+    LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
+    Returns values from a textline. Points , [Confidences], [Transcriptions]
+    """
+    confidence = 0.0
+    transcription = ""
+    points = []
+
+    numPoints = 4
+
+    if LTRB:
+
+        numPoints = 4
+
+        if withTranscription and withConfidence:
+            m = re.match(
+                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
+            if m == None:
+                m = re.match(
+                    r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',
+                    line)
+
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
+        elif withConfidence:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',
+                         line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
+        elif withTranscription:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
+        else:
+            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")
+
+        xmin = int(m.group(1))
+        ymin = int(m.group(2))
+        xmax = int(m.group(3))
+        ymax = int(m.group(4))
+        if (xmax < xmin):
+            raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
+        if (ymax < ymin):
+            raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))
+
+        points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+
+        if (imWidth > 0 and imHeight > 0):
+            validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
+            validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)
+
+    else:
+
+        numPoints = 8
+
+        if withTranscription and withConfidence:
+            m = re.match(
+                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',
+                line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
+        elif withConfidence:
+            m = re.match(
+                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',
+                line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
+        elif withTranscription:
+            m = re.match(
+                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',
+                line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
+        else:
+            m = re.match(
+                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',
+                line)
+            if m == None:
+                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")
+
+        points = [float(m.group(i)) for i in range(1, (numPoints + 1))]
+
+        validate_clockwise_points(points)
+
+        if (imWidth > 0 and imHeight > 0):
+            validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
+            validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
+            validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
+            validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)
+
+    if withConfidence:
+        try:
+            confidence = float(m.group(numPoints + 1))
+        except ValueError:
+
+            raise Exception("Confidence value must be a float")
+
+    if withTranscription:
+        posTranscription = numPoints + (2 if withConfidence else 1)
+        transcription = m.group(posTranscription)
+        m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
+        if m2 != None:  # Transcription with double quotes, we extract the value and replace escaped characters
+            transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")
+
+    return points, confidence, transcription
+
+
+def validate_point_inside_bounds(x, y, imWidth, imHeight):
+    if (x < 0 or x > imWidth):
+        raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight))
+    if (y < 0 or y > imHeight):
+        raise Exception(
+            "Y value (%s) not valid. Image dimensions: (%s,%s)" % (y, imWidth, imHeight))
+
+
+def validate_clockwise_points(points):
+    """
+    Validates that the 4 points that delimit a polygon are in clockwise order.
+    """
+
+    if len(points) != 8:
+        raise Exception("Points list not valid." + str(len(points)))
+
+    point = [
+        [int(points[0]), int(points[1])],
+        [int(points[2]), int(points[3])],
+        [int(points[4]), int(points[5])],
+        [int(points[6]), int(points[7])]
+    ]
+    edge = [
+        (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
+        (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
+        (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
+        (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
+    ]
+
+    summatory = edge[0] + edge[1] + edge[2] + edge[3];
+    if summatory > 0:
+        raise Exception(
+            "Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the train_images coordinate system used is the standard one, with the train_images origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
+
+
+def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False,
+                                          imWidth=0, imHeight=0, sort_by_confidences=True):
+    """
+    Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
+    xmin,ymin,xmax,ymax,[confidence],[transcription]
+    x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
+    """
+    pointsList = []
+    transcriptionsList = []
+    confidencesList = []
+
+    lines = content.split("\r\n" if CRLF else "\n")
+    for line in lines:
+        line = line.replace("\r", "").replace("\n", "")
+        if (line != ""):
+            points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence,
+                                                                   imWidth, imHeight);
+            pointsList.append(points)
+            transcriptionsList.append(transcription)
+            confidencesList.append(confidence)
+
+    if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
+        import numpy as np
+        sorted_ind = np.argsort(-np.array(confidencesList))
+        confidencesList = [confidencesList[i] for i in sorted_ind]
+        pointsList = [pointsList[i] for i in sorted_ind]
+        transcriptionsList = [transcriptionsList[i] for i in sorted_ind]
+
+    return pointsList, confidencesList, transcriptionsList
+
+
+def main_evaluation(p, evalParams, evaluate_method_fn, show_result=True, per_sample=True):
+    """
+    This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
+    Params:
+    p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
+    default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+    validate_data_fn: points to a method that validates the correct format of the submission
+    evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
+    """
+    if 'p' in p.keys():
+        evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))
+
+    resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'}
+    try:
+        # validate_data_fn(p['g'], p['s'], evalParams)
+        evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
+        resDict.update(evalData)
+
+    except Exception as e:
+        traceback.print_exc()
+        resDict['Message'] = str(e)
+        resDict['calculated'] = False
+
+    if 'o' in p:
+        if not os.path.exists(p['o']):
+            os.makedirs(p['o'])
+
+        resultsOutputname = p['o'] + '/results.zip'
+        outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)
+
+        del resDict['per_sample']
+        if 'output_items' in resDict.keys():
+            del resDict['output_items']
+
+        outZip.writestr('method.json', json.dumps(resDict))
+
+    if not resDict['calculated']:
+        if show_result:
+            sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n')
+        if 'o' in p:
+            outZip.close()
+        return resDict
+
+    if 'o' in p:
+        if per_sample == True:
+            for k, v in evalData['per_sample'].items():
+                outZip.writestr(k + '.json', json.dumps(v))
+
+            if 'output_items' in evalData.keys():
+                for k, v in evalData['output_items'].items():
+                    outZip.writestr(k, v)
+
+        outZip.close()
+
+    if show_result:
+        sys.stdout.write("Calculated!")
+        sys.stdout.write(json.dumps(resDict['method']))
+
+    return resDict
+
+
+def main_validation(default_evaluation_params_fn, validate_data_fn):
+    """
+    This process validates a method
+    Params:
+    default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
+    validate_data_fn: points to a method that validates the correct format of the submission
+    """
+    try:
+        p = dict([s[1:].split('=') for s in sys.argv[1:]])
+        evalParams = default_evaluation_params_fn()
+        if 'p' in p.keys():
+            evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))
+
+        validate_data_fn(p['g'], p['s'], evalParams)
+        print('SUCCESS')
+        sys.exit(0)
+    except Exception as e:
+        print(str(e))
+        sys.exit(101)
diff --git a/benchmark/cal_rescall/script.py b/benchmark/cal_rescall/script.py
new file mode 100644
index 00000000..227a05ec
--- /dev/null
+++ b/benchmark/cal_rescall/script.py
@@ -0,0 +1,283 @@
+# -*- coding: utf-8 -*-
+# https://rrc.cvc.uab.es/?ch=2&com=mymethods&task=1
+
+from collections import namedtuple
+from . import rrc_evaluation_funcs
+import Polygon as plg
+import numpy as np
+
+
+def evaluate_method(gtFilePath, submFilePath, evaluationParams):
+    """
+    Method evaluate_method: evaluates the method and returns the results
+    Results. Dictionary with the following values:
+    - method (required)  Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
+    - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
+    """
+
+    def polygon_from_points(points):
+        """
+        Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
+        """
+        resBoxes = np.empty([1, 8], dtype='int32')
+        resBoxes[0, 0] = int(points[0])
+        resBoxes[0, 4] = int(points[1])
+        resBoxes[0, 1] = int(points[2])
+        resBoxes[0, 5] = int(points[3])
+        resBoxes[0, 2] = int(points[4])
+        resBoxes[0, 6] = int(points[5])
+        resBoxes[0, 3] = int(points[6])
+        resBoxes[0, 7] = int(points[7])
+        pointMat = resBoxes[0].reshape([2, 4]).T
+        return plg.Polygon(pointMat)
+
+    def rectangle_to_polygon(rect):
+        resBoxes = np.empty([1, 8], dtype='int32')
+        resBoxes[0, 0] = int(rect.xmin)
+        resBoxes[0, 4] = int(rect.ymax)
+        resBoxes[0, 1] = int(rect.xmin)
+        resBoxes[0, 5] = int(rect.ymin)
+        resBoxes[0, 2] = int(rect.xmax)
+        resBoxes[0, 6] = int(rect.ymin)
+        resBoxes[0, 3] = int(rect.xmax)
+        resBoxes[0, 7] = int(rect.ymax)
+
+        pointMat = resBoxes[0].reshape([2, 4]).T
+
+        return plg.Polygon(pointMat)
+
+    def rectangle_to_points(rect):
+        points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin),
+                  int(rect.xmin), int(rect.ymin)]
+        return points
+
+    def get_union(pD, pG):
+        areaA = pD.area()
+        areaB = pG.area()
+        return areaA + areaB - get_intersection(pD, pG)
+
+    def get_intersection_over_union(pD, pG):
+        try:
+            return get_intersection(pD, pG) / get_union(pD, pG)
+        except:
+            return 0
+
+    def get_intersection(pD, pG):
+        pInt = pD & pG
+        if len(pInt) == 0:
+            return 0
+        return pInt.area()
+
+    def compute_ap(confList, matchList, numGtCare):
+        correct = 0
+        AP = 0
+        if len(confList) > 0:
+            confList = np.array(confList)
+            matchList = np.array(matchList)
+            sorted_ind = np.argsort(-confList)
+            confList = confList[sorted_ind]
+            matchList = matchList[sorted_ind]
+            for n in range(len(confList)):
+                match = matchList[n]
+                if match:
+                    correct += 1
+                    AP += float(correct) / (n + 1)
+
+            if numGtCare > 0:
+                AP /= numGtCare
+
+        return AP
+
+    perSampleMetrics = {}
+
+    matchedSum = 0
+
+    Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+
+    gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
+    subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
+
+    numGlobalCareGt = 0
+    numGlobalCareDet = 0
+
+    arrGlobalConfidences = []
+    arrGlobalMatches = []
+
+    for resFile in gt:
+
+        gtFile = gt[resFile]  # rrc_evaluation_funcs.decode_utf8(gt[resFile])
+        recall = 0
+        precision = 0
+        hmean = 0
+
+        detMatched = 0
+
+        iouMat = np.empty([1, 1])
+
+        gtPols = []
+        detPols = []
+
+        gtPolPoints = []
+        detPolPoints = []
+
+        # Array of Ground Truth Polygons' keys marked as don't Care
+        gtDontCarePolsNum = []
+        # Array of Detected Polygons' matched with a don't Care GT
+        detDontCarePolsNum = []
+
+        pairs = []
+        detMatchedNums = []
+
+        arrSampleConfidences = []
+        arrSampleMatch = []
+        sampleAP = 0
+
+        evaluationLog = ""
+
+        pointsList, _, transcriptionsList = rrc_evaluation_funcs. \
+            get_tl_line_values_from_file_contents(gtFile,
+                                                  evaluationParams['CRLF'], \
+                                                  evaluationParams['LTRB'], \
+                                                  True, False)
+        for n in range(len(pointsList)):
+            points = pointsList[n]
+            transcription = transcriptionsList[n]
+            dontCare = transcription == "###"
+            if evaluationParams['LTRB']:
+                gtRect = Rectangle(*points)
+                gtPol = rectangle_to_polygon(gtRect)
+            else:
+                gtPol = polygon_from_points(points)
+            gtPols.append(gtPol)
+            gtPolPoints.append(points)
+            if dontCare:
+                gtDontCarePolsNum.append(len(gtPols) - 1)
+
+        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
+            " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
+
+        if resFile in subm:
+
+            detFile = subm[resFile]  # rrc_evaluation_funcs.decode_utf8(subm[resFile])
+
+            pointsList, confidencesList, _ = rrc_evaluation_funcs. \
+                get_tl_line_values_from_file_contents(detFile, \
+                                                      evaluationParams['CRLF'], \
+                                                      evaluationParams['LTRB'], \
+                                                      False, \
+                                                      evaluationParams['CONFIDENCES'])
+            for n in range(len(pointsList)):
+                points = pointsList[n]
+
+                if evaluationParams['LTRB']:
+                    detRect = Rectangle(*points)
+                    detPol = rectangle_to_polygon(detRect)
+                else:
+                    detPol = polygon_from_points(points)
+                detPols.append(detPol)
+                detPolPoints.append(points)
+                if len(gtDontCarePolsNum) > 0:
+                    for dontCarePol in gtDontCarePolsNum:
+                        dontCarePol = gtPols[dontCarePol]
+                        intersected_area = get_intersection(dontCarePol, detPol)
+                        pdDimensions = detPol.area()
+                        precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+                        if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
+                            detDontCarePolsNum.append(len(detPols) - 1)
+                            break
+
+            evaluationLog += "DET polygons: " + str(len(detPols)) + (
+                " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
+
+            if len(gtPols) > 0 and len(detPols) > 0:
+                # Calculate IoU and precision matrices
+                outputShape = [len(gtPols), len(detPols)]
+                iouMat = np.empty(outputShape)
+                gtRectMat = np.zeros(len(gtPols), np.int8)
+                detRectMat = np.zeros(len(detPols), np.int8)
+                for gtNum in range(len(gtPols)):
+                    for detNum in range(len(detPols)):
+                        pG = gtPols[gtNum]
+                        pD = detPols[detNum]
+                        iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+
+                for gtNum in range(len(gtPols)):
+                    for detNum in range(len(detPols)):
+                        if gtRectMat[gtNum] == 0 and detRectMat[
+                            detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+                            if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
+                                gtRectMat[gtNum] = 1
+                                detRectMat[detNum] = 1
+                                detMatched += 1
+                                pairs.append({'gt': gtNum, 'det': detNum})
+                                detMatchedNums.append(detNum)
+                                evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"
+
+            if evaluationParams['CONFIDENCES']:
+                for detNum in range(len(detPols)):
+                    if detNum not in detDontCarePolsNum:
+                        # we exclude the don't care detections
+                        match = detNum in detMatchedNums
+
+                        arrSampleConfidences.append(confidencesList[detNum])
+                        arrSampleMatch.append(match)
+
+                        arrGlobalConfidences.append(confidencesList[detNum])
+                        arrGlobalMatches.append(match)
+
+        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+        numDetCare = (len(detPols) - len(detDontCarePolsNum))
+        if numGtCare == 0:
+            recall = float(1)
+            precision = float(0) if numDetCare > 0 else float(1)
+            sampleAP = precision
+        else:
+            recall = float(detMatched) / numGtCare
+            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
+            if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
+                sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
+
+        hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
+
+        matchedSum += detMatched
+        numGlobalCareGt += numGtCare
+        numGlobalCareDet += numDetCare
+
+        if evaluationParams['PER_SAMPLE_RESULTS']:
+            perSampleMetrics[resFile] = {
+                'precision': precision,
+                'recall': recall,
+                'hmean': hmean,
+                'pairs': pairs,
+                'AP': sampleAP,
+                'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+                'gtPolPoints': gtPolPoints,
+                'detPolPoints': detPolPoints,
+                'gtDontCare': gtDontCarePolsNum,
+                'detDontCare': detDontCarePolsNum,
+                'evaluationParams': evaluationParams,
+                'evaluationLog': evaluationLog
+            }
+
+    # Compute MAP and MAR
+    AP = 0
+    if evaluationParams['CONFIDENCES']:
+        AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
+
+    methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
+    methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
+    methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
+        methodRecall + methodPrecision)
+
+    methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
+
+    resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
+
+    return resDict
+
+
+def cal_recall_precison_f1(gt_path, result_path, default_evaluation_params, show_result=False):
+    p = {'g': gt_path, 's': result_path}
+    result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, evaluate_method,
+                                                  show_result)
+    return result['method']
diff --git a/benchmark/config/text_detection_db_metrics.yaml b/benchmark/config/text_detection_db_metrics.yaml
new file mode 100644
index 00000000..9186a3e1
--- /dev/null
+++ b/benchmark/config/text_detection_db_metrics.yaml
@@ -0,0 +1,21 @@
+Benchmark:
+  name: "Text Detection metrics Benchmark"
+  type: "Detection metrics"
+  data:
+    gt: "./data/icdar2015/gt"
+    imgs: "./data/icdar2015/img"
+    out: "./data/icdar2015/out"
+    iou_constraint: 0.3
+    area_precision_constraint: 0.8
+    height: 736
+    width: 736
+  backend: "default"
+  target: "cpu"
+
+Model:
+  name: "DB"
+  modelPath: "models/text_detection_db/text_detection_DB_TD500_resnet18_2021sep.onnx"
+  binaryThreshold: 0.3
+  polygonThreshold: 0.5
+  maxCandidates: 200
+  unclipRatio: 2.0
\ No newline at end of file
diff --git a/benchmark/detection_metric.py b/benchmark/detection_metric.py
new file mode 100644
index 00000000..5ed9c47f
--- /dev/null
+++ b/benchmark/detection_metric.py
@@ -0,0 +1,87 @@
+# This file is part of OpenCV Zoo project.
+# It is subject to the license terms in the LICENSE file found in the same directory.
+#
+
+import argparse
+import glob
+import time
+
+import numpy as np
+import cv2 as cv
+from cal_rescall.script import cal_recall_precison_f1
+
+
+
+
+def detect_metric(model, **kwargs):
+    backend_id = kwargs.pop('backend', 'default')
+    data = kwargs['data']
+    img_dir = data.pop('imgs')
+    gt_dir = data.pop('gt')
+    out_dir = data.pop('out')
+    width = data.pop('width')
+    height = data.pop('height')
+    iou_constraint = data.pop("iou_constraint")
+    area_precision_constraint = data.pop("area_precision_constraint")
+    available_backends = dict(
+        default=cv.dnn.DNN_BACKEND_DEFAULT,
+        # halide=cv.dnn.DNN_BACKEND_HALIDE,
+        # inference_engine=cv.dnn.DNN_BACKEND_INFERENCE_ENGINE,
+        opencv=cv.dnn.DNN_BACKEND_OPENCV,
+        # vkcom=cv.dnn.DNN_BACKEND_VKCOM,
+        cuda=cv.dnn.DNN_BACKEND_CUDA,
+    )
+
+    target_id = kwargs.pop('target', 'cpu')
+    available_targets = dict(
+        cpu=cv.dnn.DNN_TARGET_CPU,
+        cuda=cv.dnn.DNN_TARGET_CUDA,
+        cuda_fp16=cv.dnn.DNN_TARGET_CUDA_FP16,
+    )
+
+    # add extra backends & targets
+    try:
+        available_backends['timvx'] = cv.dnn.DNN_BACKEND_TIMVX
+        available_targets['npu'] = cv.dnn.DNN_TARGET_NPU
+    except:
+        print(
+            'OpenCV is not compiled with TIM-VX backend enabled. See https://github.com/opencv/opencv/wiki/TIM-VX-Backend-For-Running-OpenCV-On-NPU for more details on how to enable TIM-VX backend.')
+
+    _backend = available_backends[backend_id]
+    _target = available_targets[target_id]
+    model.setBackend(_backend)
+    model.setTarget(_target)
+    default_evaluation_params = {
+        'IOU_CONSTRAINT': iou_constraint,
+        'AREA_PRECISION_CONSTRAINT': area_precision_constraint,
+        'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
+        'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
+        'LTRB': False,  # LTRB: 2 points (left,top,right,bottom) or 4 points (x1,y1,x2,y2,x3,y3,x4,y4)
+        'CRLF': False,  # Lines are delimited by Windows CRLF format
+        'CONFIDENCES': True,  # Detections must include a confidence value. AP will be calculated
+        'PER_SAMPLE_RESULTS': False  # Generate per sample results and produce data for visualization
+    }
+    # Instantiate DB
+
+    files = glob.glob(img_dir + '/*', recursive=True)
+
+    for file in files:
+        image = cv.imread(file)
+        scale = (image.shape[1] * 1.0 / width, image.shape[0] * 1.0 / height)
+        image = cv.resize(image, [width, height])
+        # Inference
+        results = model.infer(image)
+        img_name = file.split('/')[-1].split('.')[0]
+        text_file = out_dir + '/res_' + img_name + '.txt'
+        result = ''
+        for idx, (bbox, score) in enumerate(zip(results[0], results[1])):
+            result += '{},{},{},{},{},{},{},{},{}\n'.format(int(bbox[0][0] * scale[0]), int(bbox[0][1] * scale[1]),
+                                                            int(bbox[1][0] * scale[0]), int(bbox[1][1] * scale[1]),
+                                                            int(bbox[2][0] * scale[0]), int(bbox[2][1] * scale[1]),
+                                                            int(bbox[3][0] * scale[0]), int(bbox[3][1] * scale[1]),
+                                                            score)
+        with open(text_file, 'w+') as fid:
+            fid.write(result)
+    result_dict = cal_recall_precison_f1(gt_dir, out_dir, default_evaluation_params)
+    print(result_dict)
+    return result_dict
diff --git a/models/text_detection_db/cal_rescall/__init__.py b/models/text_detection_db/cal_rescall/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/models/text_detection_db/demo.py b/models/text_detection_db/demo.py
index b34e5491..312dd508 100644
--- a/models/text_detection_db/demo.py
+++ b/models/text_detection_db/demo.py
@@ -89,7 +89,7 @@ def visualize(image, results, box_color=(0, 255, 0), text_color=(0, 0, 255), isC
 
     # Save results if save is true
     if args.save:
-        print('Resutls saved to result.jpg\n')
+        print('Results saved to result.jpg\n')
         cv.imwrite('result.jpg', image)
 
     # Visualize results in a new window
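
Usage note (editor's sketch, not part of the patch): end to end, the new path is selected by the Benchmark.type field, e.g. python benchmark.py --cfg ./config/text_detection_db_metrics.yaml from the benchmark/ directory, assuming the existing --cfg flag of benchmark.py. The scoring half can also be driven on its own once res_img_<N>.txt files have been dumped; the sketch below mirrors the parameter dict that detect_metric() builds and the ICDAR2015 paths from text_detection_db_metrics.yaml, and assumes those gt/out folders already exist locally.

    # Standalone scoring sketch: evaluate a folder of res_img_<N>.txt detections
    # against gt_img_<N>.txt ground truth with the new cal_rescall package.
    from cal_rescall.script import cal_recall_precison_f1

    eval_params = {
        'IOU_CONSTRAINT': 0.3,
        'AREA_PRECISION_CONSTRAINT': 0.8,
        'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
        'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
        'LTRB': False,        # quads: x1,y1,x2,y2,x3,y3,x4,y4
        'CRLF': False,
        'CONFIDENCES': True,  # result lines end with a confidence, so AP is computed
        'PER_SAMPLE_RESULTS': False,
    }

    # Paths are placeholders taken from the YAML config.
    metrics = cal_recall_precison_f1('./data/icdar2015/gt', './data/icdar2015/out', eval_params)
    print(metrics)  # {'precision': ..., 'recall': ..., 'hmean': ..., 'AP': ...}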