From 6001f6099cc2f99041f6fb466841ff8c9c85ced7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E9=80=9A?= <11911611@mail.sustech.edu.cn> Date: Wed, 13 Jul 2022 21:48:33 +0800 Subject: [PATCH 1/5] add text detection metrics --- .../text_detection_db/cal_rescall/__init__.py | 0 .../cal_rescall/rrc_evaluation_funcs.py | 387 ++++++++++++++++++ models/text_detection_db/demo.py | 2 +- models/text_detection_db/validation.py | 106 +++++ 4 files changed, 494 insertions(+), 1 deletion(-) create mode 100644 models/text_detection_db/cal_rescall/__init__.py create mode 100644 models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py create mode 100644 models/text_detection_db/validation.py diff --git a/models/text_detection_db/cal_rescall/__init__.py b/models/text_detection_db/cal_rescall/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py b/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py new file mode 100644 index 00000000..d6ac1c3d --- /dev/null +++ b/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py @@ -0,0 +1,387 @@ +#encoding: UTF-8 +import json +import sys +sys.path.append('./') +import zipfile +import re +import sys +import os +import codecs +import traceback + +def print_help(): + sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file,fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( keyName ) + + return pairs + + +def load_zip_file(file,fileNameRegExp='',allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive=zipfile.ZipFile(file, mode='r', allowZip64=True) + except : + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp!="": + m = re.match(fileNameRegExp,name) + if m == None: + addFile = False + else: + if len(m.groups())>0: + keyName = m.group(1) + + if addFile: + pairs.append( [ keyName , archive.read(name)] ) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' %name) + + return dict(pairs) + + +def load_folder_file(file, fileNameRegExp='', allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 
+ The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + pairs = [] + for name in os.listdir(file): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append([keyName, open(os.path.join(file,name)).read()]) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' % name) + + return dict(pairs) + + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw,'utf-8', 'replace') + #extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + +def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None) : + raise Exception("The file %s is not UTF-8" %fileName) + + lines = utf8File.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != ""): + try: + validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + except Exception as e: + raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) + + + +def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + + +def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = "" + points = [] + + numPoints = 4 + + if LTRB: + + numPoints = 4 + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. 
Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if(xmax0 and imHeight>0): + validate_point_inside_bounds(xmin,ymin,imWidth,imHeight) + validate_point_inside_bounds(xmax,ymax,imWidth,imHeight) + + else: + + numPoints = 8 + + if withTranscription and withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) + if m == None : + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] + + validate_clockwise_points(points) + + if (imWidth>0 and imHeight>0): + validate_point_inside_bounds(points[0],points[1],imWidth,imHeight) + validate_point_inside_bounds(points[2],points[3],imWidth,imHeight) + validate_point_inside_bounds(points[4],points[5],imWidth,imHeight) + validate_point_inside_bounds(points[6],points[7],imWidth,imHeight) + + + if withConfidence: + try: + confidence = float(m.group(numPoints+1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) + if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points,confidence,transcription + + +def validate_point_inside_bounds(x,y,imWidth,imHeight): + if(x<0 or x>imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(x,imWidth,imHeight)) + if(y<0 or y>imHeight): + raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(y,imWidth,imHeight)) + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." 
+ str(len(points))) + + point = [ + [int(points[0]) , int(points[1])], + [int(points[2]) , int(points[3])], + [int(points[4]) , int(points[5])], + [int(points[6]) , int(points[7])] + ] + edge = [ + ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), + ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), + ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), + ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) + ] + + summatory = edge[0] + edge[1] + edge[2] + edge[3]; + if summatory>0: + raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the train_images coordinate system used is the standard one, with the train_images origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + +def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split( "\r\n" if CRLF else "\n" ) + for line in lines: + line = line.replace("\r","").replace("\n","") + if(line != "") : + points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList)>0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList,confidencesList,transcriptionsList + +def main_evaluation(p,evalParams,evaluate_method_fn,show_result=True,per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} + try: + # validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + traceback.print_exc() + resDict['Message']= str(e) + resDict['calculated']=False + + if 'o' in p: + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in resDict.keys(): + del resDict['output_items'] + + outZip.writestr('method.json',json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') + if 'o' in p: + outZip.close() + return resDict + + if 'o' in p: + if per_sample == True: + for k,v in evalData['per_sample'].iteritems(): + outZip.writestr( k + '.json',json.dumps(v)) + + if 'output_items' in evalData.keys(): + for k, v in evalData['output_items'].iteritems(): + outZip.writestr( k,v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn,validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) \ No newline at end of file diff --git a/models/text_detection_db/demo.py b/models/text_detection_db/demo.py index b34e5491..312dd508 100644 --- a/models/text_detection_db/demo.py +++ b/models/text_detection_db/demo.py @@ -89,7 +89,7 @@ def visualize(image, results, box_color=(0, 255, 0), text_color=(0, 0, 255), isC # Save results if save is true if args.save: - print('Resutls saved to result.jpg\n') + print('Results saved to result.jpg\n') cv.imwrite('result.jpg', image) # Visualize results in a new window diff --git a/models/text_detection_db/validation.py b/models/text_detection_db/validation.py new file mode 100644 index 00000000..24928b83 --- /dev/null +++ b/models/text_detection_db/validation.py @@ -0,0 +1,106 @@ +# This file is part of OpenCV Zoo project. +# It is subject to the license terms in the LICENSE file found in the same directory. +# +# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved. +# Third party copyrights are property of their respective owners. 
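+#
+# validation.py evaluates the DB text detection model on an ICDAR2015-style test set:
+# it runs the detector on every image under --img_dir, writes one prediction file per
+# image to --out_dir ('res_' + image name + '.txt', eight corner coordinates plus a
+# confidence per line), and scores the predictions against the ground-truth files in
+# --gt_dir with cal_rescall.cal_recall_precison_f1 (ICDAR-style precision/recall/hmean).
+# The evaluation regexes assume images are named img_<id>, so that predictions match
+# 'res_img_([0-9]+).txt' and ground truth matches 'gt_img_([0-9]+).txt'.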
+ +import argparse +import glob +import sys +import time + +import numpy as np +import cv2 as cv + +from db import DB +from cal_rescall.script import cal_recall_precison_f1 + + +def str2bool(v): + if v.lower() in ['on', 'yes', 'true', 'y', 't']: + return True + elif v.lower() in ['off', 'no', 'false', 'n', 'f']: + return False + else: + raise NotImplementedError + +backends = [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_CUDA] +targets = [cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16] +help_msg_backends = "Choose one of the computation backends: {:d}: OpenCV implementation (default); {:d}: CUDA" +help_msg_targets = "Chose one of the target computation devices: {:d}: CPU (default); {:d}: CUDA; {:d}: CUDA fp16" +try: + backends += [cv.dnn.DNN_BACKEND_TIMVX] + targets += [cv.dnn.DNN_TARGET_NPU] + help_msg_backends += "; {:d}: TIMVX" + help_msg_targets += "; {:d}: NPU" +except: + print('This version of OpenCV does not support TIM-VX and NPU. Visit https://gist.github.com/fengyuentau/5a7a5ba36328f2b763aea026c43fa45f for more information.') + +parser = argparse.ArgumentParser(description='Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947).') +parser.add_argument('--input', '-i', type=str, help='Path to the input image. Omit for using default camera.') +parser.add_argument('--model', '-m', type=str, default='text_detection_DB_TD500_resnet18_2021sep.onnx', help='Path to the model.') +parser.add_argument('--gt_dir', type=str, default='icdar2015/test_gts', help='Path to the ground truth txt directory.') +parser.add_argument('--out_dir', type=str, default='icdar2015/test_predicts', help='Path to the output txt directory.') +parser.add_argument('--img_dir', type=str, default='icdar2015/test_images', help='Path to the test images directory.') +parser.add_argument('--backend', '-b', type=int, default=backends[0], help=help_msg_backends.format(*backends)) +parser.add_argument('--target', '-t', type=int, default=targets[0], help=help_msg_targets.format(*targets)) +parser.add_argument('--width', type=int, default=736, + help='Preprocess input image by resizing to a specific width. It should be multiple by 32.') +parser.add_argument('--height', type=int, default=736, + help='Preprocess input image by resizing to a specific height. It should be multiple by 32.') +parser.add_argument('--binary_threshold', type=float, default=0.3, help='Threshold of the binary map.') +parser.add_argument('--polygon_threshold', type=float, default=0.5, help='Threshold of polygons.') +parser.add_argument('--iou_constraint', type=float, default=0.5, help='IOU constraint.') +parser.add_argument('--area_precision_constraint', type=float, default=0.3, help='Area precision constraint.') +parser.add_argument('--max_candidates', type=int, default=200, help='Max candidates of polygons.') +parser.add_argument('--unclip_ratio', type=np.float64, default=2.0, help=' The unclip ratio of the detected text region, which determines the output size.') +parser.add_argument('--save', '-s', type=str, default=False, help='Set true to save results. This flag is invalid when using camera.') +parser.add_argument('--vis', '-v', type=str2bool, default=True, help='Set true to open a window for result visualization. 
This flag is invalid when using camera.') +args = parser.parse_args() + + + +if __name__ == '__main__': + default_evaluation_params={ + 'IOU_CONSTRAINT': args.iou_constraint, + 'AREA_PRECISION_CONSTRAINT': args.area_precision_constraint, + 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', + 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF': False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES': True, # Detections must include confidence value. AP will be calculated + 'PER_SAMPLE_RESULTS':False # Generate per sample results and produce data for visualization + } + # Instantiate DB + model = DB(modelPath=args.model, + inputSize=[args.width, args.height], + binaryThreshold=args.binary_threshold, + polygonThreshold=args.polygon_threshold, + maxCandidates=args.max_candidates, + unclipRatio=args.unclip_ratio, + backendId=args.backend, + targetId=args.target + ) + files = glob.glob(args.img_dir+'/*', recursive=True) + start = time.time() + for file in files: + image = cv.imread(file) + image = cv.resize(image, [args.width, args.height]) + + # Inference + results = model.infer(image) + img_name=file.split('/')[-1].split('.')[0] + text_file = args.out_dir+'res_' + img_name + '.txt' + result='' + for idx, (bbox, score) in enumerate(zip(results[0], results[1])): + result+='{},{},{},{},{},{},{},{},{}\n'.format(bbox[0][0],bbox[0][1], bbox[1][0], bbox[1][1], bbox[2][0],bbox[2][1], bbox[3][0], bbox[3][1],score) + with open(text_file, 'w+') as fid: + fid.write(result) + end = time.time() + avg_time=(end-start)/len(files) + result_dict = cal_recall_precison_f1(args.gt_dir,args.out_dir,default_evaluation_params) + # Print results + result_dict['avg time']=avg_time + print(result_dict) + + From 599f7b318ca43fd201725c8afa31722ab882eec5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E9=80=9A?= <11911611@mail.sustech.edu.cn> Date: Wed, 13 Jul 2022 21:49:14 +0800 Subject: [PATCH 2/5] add text detection metrics --- .../text_detection_db/cal_rescall/script.py | 284 ++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 models/text_detection_db/cal_rescall/script.py diff --git a/models/text_detection_db/cal_rescall/script.py b/models/text_detection_db/cal_rescall/script.py new file mode 100644 index 00000000..06039b5f --- /dev/null +++ b/models/text_detection_db/cal_rescall/script.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +from collections import namedtuple +from . import rrc_evaluation_funcs +import Polygon as plg +import numpy as np + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + + def polygon_from_points(points): + """ + Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 + """ + resBoxes = np.empty([1, 8], dtype='int32') + resBoxes[0, 0] = int(points[0]) + resBoxes[0, 4] = int(points[1]) + resBoxes[0, 1] = int(points[2]) + resBoxes[0, 5] = int(points[3]) + resBoxes[0, 2] = int(points[4]) + resBoxes[0, 6] = int(points[5]) + resBoxes[0, 3] = int(points[6]) + resBoxes[0, 7] = int(points[7]) + pointMat = resBoxes[0].reshape([2, 4]).T + return plg.Polygon(pointMat) + + def rectangle_to_polygon(rect): + resBoxes = np.empty([1, 8], dtype='int32') + resBoxes[0, 0] = int(rect.xmin) + resBoxes[0, 4] = int(rect.ymax) + resBoxes[0, 1] = int(rect.xmin) + resBoxes[0, 5] = int(rect.ymin) + resBoxes[0, 2] = int(rect.xmax) + resBoxes[0, 6] = int(rect.ymin) + resBoxes[0, 3] = int(rect.xmax) + resBoxes[0, 7] = int(rect.ymax) + + pointMat = resBoxes[0].reshape([2, 4]).T + + return plg.Polygon(pointMat) + + def rectangle_to_points(rect): + points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), + int(rect.xmin), int(rect.ymin)] + return points + + def get_union(pD, pG): + areaA = pD.area() + areaB = pG.area() + return areaA + areaB - get_intersection(pD, pG) + + def get_intersection_over_union(pD, pG): + try: + return get_intersection(pD, pG) / get_union(pD, pG) + except: + return 0 + + def get_intersection(pD, pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + arrGlobalConfidences = [] + arrGlobalMatches = [] + + for resFile in gt: + + gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile]) + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + arrSampleConfidences = [] + arrSampleMatch = [] + sampleAP = 0 + + evaluationLog = "" + + pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, + evaluationParams[ + 'CRLF'], + evaluationParams[ + 'LTRB'], + True, False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + if evaluationParams['LTRB']: + gtRect = Rectangle(*points) + gtPol = rectangle_to_polygon(gtRect) + else: + gtPol = polygon_from_points(points) + gtPols.append(gtPol) + 
gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + ( + " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n") + + if resFile in subm: + + detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile]) + + pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, + evaluationParams[ + 'CRLF'], + evaluationParams[ + 'LTRB'], + False, + evaluationParams[ + 'CONFIDENCES']) + for n in range(len(pointsList)): + points = pointsList[n] + + if evaluationParams['LTRB']: + detRect = Rectangle(*points) + detPol = rectangle_to_polygon(detRect) + else: + detPol = polygon_from_points(points) + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = detPol.area() + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + ( + " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n") + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" + + if evaluationParams['CONFIDENCES']: + for detNum in range(len(detPols)): + if detNum not in detDontCarePolsNum: + # we exclude the don't care detections + match = detNum in detMatchedNums + + arrSampleConfidences.append(confidencesList[detNum]) + arrSampleMatch.append(match) + + arrGlobalConfidences.append(confidencesList[detNum]) + arrGlobalMatches.append(match) + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + sampleAP = precision + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare + if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: + sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare) + + hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + if evaluationParams['PER_SAMPLE_RESULTS']: + perSampleMetrics[resFile] = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'AP': sampleAP, + 
'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + # Compute MAP and MAR + AP = 0 + if evaluationParams['CONFIDENCES']: + AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) + + methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP} + + resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics} + + return resDict + + +def cal_recall_precison_f1(gt_path, result_path, default_evaluation_params, show_result=False): + p = {'g': gt_path, 's': result_path} + result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, evaluate_method, + show_result) + return result['method'] \ No newline at end of file From 4fa742329c0686b429ed672655a83e1e6eef231d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E9=80=9A?= <11911611@mail.sustech.edu.cn> Date: Thu, 14 Jul 2022 21:52:04 +0800 Subject: [PATCH 3/5] change shape --- models/text_detection_db/validation.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/models/text_detection_db/validation.py b/models/text_detection_db/validation.py index 24928b83..de272ddf 100644 --- a/models/text_detection_db/validation.py +++ b/models/text_detection_db/validation.py @@ -42,16 +42,16 @@ def str2bool(v): parser.add_argument('--gt_dir', type=str, default='icdar2015/test_gts', help='Path to the ground truth txt directory.') parser.add_argument('--out_dir', type=str, default='icdar2015/test_predicts', help='Path to the output txt directory.') parser.add_argument('--img_dir', type=str, default='icdar2015/test_images', help='Path to the test images directory.') -parser.add_argument('--backend', '-b', type=int, default=backends[0], help=help_msg_backends.format(*backends)) -parser.add_argument('--target', '-t', type=int, default=targets[0], help=help_msg_targets.format(*targets)) parser.add_argument('--width', type=int, default=736, help='Preprocess input image by resizing to a specific width. It should be multiple by 32.') parser.add_argument('--height', type=int, default=736, help='Preprocess input image by resizing to a specific height. 
It should be multiple by 32.') +parser.add_argument('--backend', '-b', type=int, default=backends[0], help=help_msg_backends.format(*backends)) +parser.add_argument('--target', '-t', type=int, default=targets[0], help=help_msg_targets.format(*targets)) parser.add_argument('--binary_threshold', type=float, default=0.3, help='Threshold of the binary map.') parser.add_argument('--polygon_threshold', type=float, default=0.5, help='Threshold of polygons.') -parser.add_argument('--iou_constraint', type=float, default=0.5, help='IOU constraint.') -parser.add_argument('--area_precision_constraint', type=float, default=0.3, help='Area precision constraint.') +parser.add_argument('--iou_constraint', type=float, default=0.3, help='IOU constraint.') +parser.add_argument('--area_precision_constraint', type=float, default=0.8, help='Area precision constraint.') parser.add_argument('--max_candidates', type=int, default=200, help='Max candidates of polygons.') parser.add_argument('--unclip_ratio', type=np.float64, default=2.0, help=' The unclip ratio of the detected text region, which determines the output size.') parser.add_argument('--save', '-s', type=str, default=False, help='Set true to save results. This flag is invalid when using camera.') @@ -82,22 +82,26 @@ def str2bool(v): targetId=args.target ) files = glob.glob(args.img_dir+'/*', recursive=True) - start = time.time() + tot_time=0 + for file in files: image = cv.imread(file) + scale = (image.shape[1] * 1.0 / args.width, image.shape[0] * 1.0 / args.height) image = cv.resize(image, [args.width, args.height]) - + start = time.time() # Inference results = model.infer(image) + end = time.time() + tot_time+=(end-start)/len(files) img_name=file.split('/')[-1].split('.')[0] - text_file = args.out_dir+'res_' + img_name + '.txt' + text_file = args.out_dir+'/res_' + img_name + '.txt' result='' for idx, (bbox, score) in enumerate(zip(results[0], results[1])): - result+='{},{},{},{},{},{},{},{},{}\n'.format(bbox[0][0],bbox[0][1], bbox[1][0], bbox[1][1], bbox[2][0],bbox[2][1], bbox[3][0], bbox[3][1],score) + result+='{},{},{},{},{},{},{},{},{}\n'.format(int(bbox[0][0]*scale[0]),int(bbox[0][1]*scale[1]), int(bbox[1][0]*scale[0]), int(bbox[1][1]*scale[1]), int(bbox[2][0]*scale[0]),int(bbox[2][1]*scale[1]), int(bbox[3][0]*scale[0]), int(bbox[3][1]*scale[1]),score) with open(text_file, 'w+') as fid: fid.write(result) - end = time.time() - avg_time=(end-start)/len(files) + + avg_time=tot_time/len(files) result_dict = cal_recall_precison_f1(args.gt_dir,args.out_dir,default_evaluation_params) # Print results result_dict['avg time']=avg_time From a2b13938c570f2e7435a9c4484d1f5ee7b6b1983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E9=80=9A?= <11911611@mail.sustech.edu.cn> Date: Wed, 20 Jul 2022 17:42:41 +0800 Subject: [PATCH 4/5] Format file --- .../cal_rescall/rrc_evaluation_funcs.py | 321 ++++++++++-------- .../text_detection_db/cal_rescall/script.py | 29 +- models/text_detection_db/validation.py | 55 +-- 3 files changed, 219 insertions(+), 186 deletions(-) diff --git a/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py b/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py index d6ac1c3d..b31576a3 100644 --- a/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py +++ b/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py @@ -1,6 +1,9 @@ -#encoding: UTF-8 +# encoding: UTF-8 +# https://rrc.cvc.uab.es/?ch=2&com=mymethods&task=1 + import json import sys + sys.path.append('./') import zipfile import re @@ -9,68 +12,70 @@ import 
codecs import traceback + def print_help(): - sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) + sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' % sys.argv[0]) sys.exit(2) - -def load_zip_file_keys(file,fileNameRegExp=''): + +def load_zip_file_keys(file, fileNameRegExp=''): """ Returns an array with the entries of the ZIP file that match with the regular expression. The key's are the names or the file or the capturing group definied in the fileNameRegExp """ try: - archive=zipfile.ZipFile(file, mode='r', allowZip64=True) - except : + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: raise Exception('Error loading the ZIP archive.') pairs = [] - + for name in archive.namelist(): addFile = True keyName = name - if fileNameRegExp!="": - m = re.match(fileNameRegExp,name) + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) if m == None: addFile = False else: - if len(m.groups())>0: + if len(m.groups()) > 0: keyName = m.group(1) - + if addFile: - pairs.append( keyName ) - + pairs.append(keyName) + return pairs - -def load_zip_file(file,fileNameRegExp='',allEntries=False): + +def load_zip_file(file, fileNameRegExp='', allEntries=False): """ Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. The key's are the names or the file or the capturing group definied in the fileNameRegExp allEntries validates that all entries in the ZIP file pass the fileNameRegExp """ try: - archive=zipfile.ZipFile(file, mode='r', allowZip64=True) - except : - raise Exception('Error loading the ZIP archive') + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + + raise Exception('Error loading the ZIP archive') pairs = [] for name in archive.namelist(): addFile = True keyName = name - if fileNameRegExp!="": - m = re.match(fileNameRegExp,name) + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) if m == None: addFile = False else: - if len(m.groups())>0: + if len(m.groups()) > 0: keyName = m.group(1) - + if addFile: - pairs.append( [ keyName , archive.read(name)] ) + pairs.append([keyName, archive.read(name)]) else: if allEntries: - raise Exception('ZIP entry not valid: %s' %name) + raise Exception('ZIP entry not valid: %s' % name) return dict(pairs) @@ -94,7 +99,7 @@ def load_folder_file(file, fileNameRegExp='', allEntries=False): keyName = m.group(1) if addFile: - pairs.append([keyName, open(os.path.join(file,name)).read()]) + pairs.append([keyName, open(os.path.join(file, name)).read()]) else: if allEntries: raise Exception('ZIP entry not valid: %s' % name) @@ -107,35 +112,39 @@ def decode_utf8(raw): Returns a Unicode object on success, or None on failure """ try: - raw = codecs.decode(raw,'utf-8', 'replace') - #extracts BOM if exists + raw = codecs.decode(raw, 'utf-8', 'replace') + # extracts BOM if exists raw = raw.encode('utf8') if raw.startswith(codecs.BOM_UTF8): raw = raw.replace(codecs.BOM_UTF8, '', 1) return raw.decode('utf-8') except: - return None - -def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + return None + + +def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, + imWidth=0, imHeight=0): """ This function validates that all lines of the file calling the Line validation function for each line """ utf8File = decode_utf8(file_contents) - if (utf8File is None) : - raise Exception("The file %s is not UTF-8" %fileName) + if (utf8File is 
None): + raise Exception("The file %s is not UTF-8" % fileName) - lines = utf8File.split( "\r\n" if CRLF else "\n" ) + lines = utf8File.split("\r\n" if CRLF else "\n") for line in lines: - line = line.replace("\r","").replace("\n","") - if(line != ""): + line = line.replace("\r", "").replace("\n", "") + if (line != ""): try: - validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) + validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight) except Exception as e: - raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) - - - -def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): + + raise Exception( + ("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode( + 'utf-8', 'replace')) + + +def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0): """ Validate the format of the line. If the line is not valid an exception will be raised. If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. @@ -143,10 +152,10 @@ def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,i LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] """ - get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) - - -def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): + get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight) + + +def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0): """ Validate the format of the line. If the line is not valid an exception will be raised. If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. @@ -158,126 +167,144 @@ def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=Fal confidence = 0.0 transcription = "" points = [] - + numPoints = 4 - + if LTRB: - + numPoints = 4 - + if withTranscription and withConfidence: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) - if m == None : - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line) + if m == None: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") elif withConfidence: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) - if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") elif withTranscription: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) - if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line) + if m == None: raise Exception("Format incorrect. 
Should be: xmin,ymin,xmax,ymax,transcription") else: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) - if m == None : + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line) + if m == None: raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") - + xmin = int(m.group(1)) ymin = int(m.group(2)) xmax = int(m.group(3)) ymax = int(m.group(4)) - if(xmax0 and imHeight>0): - validate_point_inside_bounds(xmin,ymin,imWidth,imHeight) - validate_point_inside_bounds(xmax,ymax,imWidth,imHeight) + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(xmin, ymin, imWidth, imHeight) + validate_point_inside_bounds(xmax, ymax, imWidth, imHeight) else: - + numPoints = 8 - + if withTranscription and withConfidence: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) - if m == None : + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + if m == None: raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") elif withConfidence: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) - if m == None : + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") elif withTranscription: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) - if m == None : + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', + line) + if m == None: raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") else: - m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) - if m == None : + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', + line) + if m == None: raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") - - points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] - + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + validate_clockwise_points(points) - - if (imWidth>0 and imHeight>0): - validate_point_inside_bounds(points[0],points[1],imWidth,imHeight) - validate_point_inside_bounds(points[2],points[3],imWidth,imHeight) - validate_point_inside_bounds(points[4],points[5],imWidth,imHeight) - validate_point_inside_bounds(points[6],points[7],imWidth,imHeight) - - + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(points[0], points[1], imWidth, imHeight) + validate_point_inside_bounds(points[2], points[3], imWidth, imHeight) + validate_point_inside_bounds(points[4], points[5], imWidth, imHeight) + validate_point_inside_bounds(points[6], points[7], imWidth, imHeight) + if withConfidence: try: - confidence = float(m.group(numPoints+1)) + confidence = float(m.group(numPoints + 1)) except ValueError: - raise Exception("Confidence value must be a float") - + + raise Exception("Confidence value must be a float") + if withTranscription: posTranscription = numPoints + (2 if withConfidence else 1) transcription = m.group(posTranscription) - m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) - if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters + m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription) + if m2 != None: # Transcription with double quotes, we extract the value and replace escaped characters transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") - - return points,confidence,transcription - - -def validate_point_inside_bounds(x,y,imWidth,imHeight): - if(x<0 or x>imWidth): - raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(x,imWidth,imHeight)) - if(y<0 or y>imHeight): - raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(y,imWidth,imHeight)) + + return points, confidence, transcription + + +def validate_point_inside_bounds(x, y, imWidth, imHeight): + if (x < 0 or x > imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight)) + if (y < 0 or y > imHeight): + raise Exception( + "Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" % (y, imWidth, imHeight)) + def validate_clockwise_points(points): """ Validates that the points that the 4 points that dlimite a polygon are in clockwise order. """ - + if len(points) != 8: raise Exception("Points list not valid." + str(len(points))) - + point = [ - [int(points[0]) , int(points[1])], - [int(points[2]) , int(points[3])], - [int(points[4]) , int(points[5])], - [int(points[6]) , int(points[7])] - ] + [int(points[0]), int(points[1])], + [int(points[2]), int(points[3])], + [int(points[4]), int(points[5])], + [int(points[6]), int(points[7])] + ] edge = [ - ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), - ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), - ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), - ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) + (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]), + (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]), + (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]), + (point[0][0] - point[3][0]) * (point[0][1] + point[3][1]) ] - + summatory = edge[0] + edge[1] + edge[2] + edge[3]; - if summatory>0: - raise Exception("Points are not clockwise. 
The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the train_images coordinate system used is the standard one, with the train_images origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + if summatory > 0: + raise Exception( + "Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the train_images coordinate system used is the standard one, with the train_images origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + -def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): +def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, + imWidth=0, imHeight=0, sort_by_confidences=True): """ Returns all points, confindences and transcriptions of a file in lists. Valid line formats: xmin,ymin,xmax,ymax,[confidence],[transcription] @@ -286,26 +313,28 @@ def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTransc pointsList = [] transcriptionsList = [] confidencesList = [] - - lines = content.split( "\r\n" if CRLF else "\n" ) + + lines = content.split("\r\n" if CRLF else "\n") for line in lines: - line = line.replace("\r","").replace("\n","") - if(line != "") : - points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); + line = line.replace("\r", "").replace("\n", "") + if (line != ""): + points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence, + imWidth, imHeight); pointsList.append(points) transcriptionsList.append(transcription) confidencesList.append(confidence) - if withConfidence and len(confidencesList)>0 and sort_by_confidences: + if withConfidence and len(confidencesList) > 0 and sort_by_confidences: import numpy as np sorted_ind = np.argsort(-np.array(confidencesList)) confidencesList = [confidencesList[i] for i in sorted_ind] pointsList = [pointsList[i] for i in sorted_ind] - transcriptionsList = [transcriptionsList[i] for i in sorted_ind] - - return pointsList,confidencesList,transcriptionsList + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList, confidencesList, transcriptionsList -def main_evaluation(p,evalParams,evaluate_method_fn,show_result=True,per_sample=True): + +def main_evaluation(p, evalParams, evaluate_method_fn, show_result=True, per_sample=True): """ This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 
Params: @@ -315,18 +344,18 @@ def main_evaluation(p,evalParams,evaluate_method_fn,show_result=True,per_sample= evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results """ if 'p' in p.keys(): - evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1])) - resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} + resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'} try: # validate_data_fn(p['g'], p['s'], evalParams) evalData = evaluate_method_fn(p['g'], p['s'], evalParams) resDict.update(evalData) - + except Exception as e: traceback.print_exc() - resDict['Message']= str(e) - resDict['calculated']=False + resDict['Message'] = str(e) + resDict['calculated'] = False if 'o' in p: if not os.path.exists(p['o']): @@ -339,49 +368,49 @@ def main_evaluation(p,evalParams,evaluate_method_fn,show_result=True,per_sample= if 'output_items' in resDict.keys(): del resDict['output_items'] - outZip.writestr('method.json',json.dumps(resDict)) - + outZip.writestr('method.json', json.dumps(resDict)) + if not resDict['calculated']: if show_result: - sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') + sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n') if 'o' in p: outZip.close() return resDict - + if 'o' in p: if per_sample == True: - for k,v in evalData['per_sample'].iteritems(): - outZip.writestr( k + '.json',json.dumps(v)) + for k, v in evalData['per_sample'].iteritems(): + outZip.writestr(k + '.json', json.dumps(v)) if 'output_items' in evalData.keys(): for k, v in evalData['output_items'].iteritems(): - outZip.writestr( k,v) + outZip.writestr(k, v) outZip.close() if show_result: sys.stdout.write("Calculated!") sys.stdout.write(json.dumps(resDict['method'])) - + return resDict -def main_validation(default_evaluation_params_fn,validate_data_fn): +def main_validation(default_evaluation_params_fn, validate_data_fn): """ This process validates a method Params: default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation validate_data_fn: points to a method that validates the corrct format of the submission - """ + """ try: p = dict([s[1:].split('=') for s in sys.argv[1:]]) evalParams = default_evaluation_params_fn() if 'p' in p.keys(): - evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) + evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1])) - validate_data_fn(p['g'], p['s'], evalParams) + validate_data_fn(p['g'], p['s'], evalParams) print('SUCCESS') sys.exit(0) except Exception as e: print(str(e)) - sys.exit(101) \ No newline at end of file + sys.exit(101) diff --git a/models/text_detection_db/cal_rescall/script.py b/models/text_detection_db/cal_rescall/script.py index 06039b5f..227a05ec 100644 --- a/models/text_detection_db/cal_rescall/script.py +++ b/models/text_detection_db/cal_rescall/script.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +# https://rrc.cvc.uab.es/?ch=2&com=mymethods&task=1 + from collections import namedtuple from . 
import rrc_evaluation_funcs import Polygon as plg @@ -132,12 +134,11 @@ def compute_ap(confList, matchList, numGtCare): evaluationLog = "" - pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, - evaluationParams[ - 'CRLF'], - evaluationParams[ - 'LTRB'], - True, False) + pointsList, _, transcriptionsList = rrc_evaluation_funcs. \ + get_tl_line_values_from_file_contents(gtFile, + evaluationParams['CRLF'], \ + evaluationParams['LTRB'], \ + True, False) for n in range(len(pointsList)): points = pointsList[n] transcription = transcriptionsList[n] @@ -159,14 +160,12 @@ def compute_ap(confList, matchList, numGtCare): detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile]) - pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, - evaluationParams[ - 'CRLF'], - evaluationParams[ - 'LTRB'], - False, - evaluationParams[ - 'CONFIDENCES']) + pointsList, confidencesList, _ = rrc_evaluation_funcs. \ + get_tl_line_values_from_file_contents(detFile, \ + evaluationParams['CRLF'], \ + evaluationParams['LTRB'], \ + False, \ + evaluationParams['CONFIDENCES']) for n in range(len(pointsList)): points = pointsList[n] @@ -281,4 +280,4 @@ def cal_recall_precison_f1(gt_path, result_path, default_evaluation_params, show p = {'g': gt_path, 's': result_path} result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, evaluate_method, show_result) - return result['method'] \ No newline at end of file + return result['method'] diff --git a/models/text_detection_db/validation.py b/models/text_detection_db/validation.py index de272ddf..a5aabcb3 100644 --- a/models/text_detection_db/validation.py +++ b/models/text_detection_db/validation.py @@ -1,8 +1,6 @@ # This file is part of OpenCV Zoo project. # It is subject to the license terms in the LICENSE file found in the same directory. # -# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved. -# Third party copyrights are property of their respective owners. import argparse import glob @@ -24,6 +22,7 @@ def str2bool(v): else: raise NotImplementedError + backends = [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_CUDA] targets = [cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16] help_msg_backends = "Choose one of the computation backends: {:d}: OpenCV implementation (default); {:d}: CUDA" @@ -34,11 +33,14 @@ def str2bool(v): help_msg_backends += "; {:d}: TIMVX" help_msg_targets += "; {:d}: NPU" except: - print('This version of OpenCV does not support TIM-VX and NPU. Visit https://gist.github.com/fengyuentau/5a7a5ba36328f2b763aea026c43fa45f for more information.') + print( + 'This version of OpenCV does not support TIM-VX and NPU. Visit https://gist.github.com/fengyuentau/5a7a5ba36328f2b763aea026c43fa45f for more information.') -parser = argparse.ArgumentParser(description='Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947).') +parser = argparse.ArgumentParser( + description='Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947).') parser.add_argument('--input', '-i', type=str, help='Path to the input image. 
Omit for using default camera.') -parser.add_argument('--model', '-m', type=str, default='text_detection_DB_TD500_resnet18_2021sep.onnx', help='Path to the model.') +parser.add_argument('--model', '-m', type=str, default='text_detection_DB_TD500_resnet18_2021sep.onnx', + help='Path to the model.') parser.add_argument('--gt_dir', type=str, default='icdar2015/test_gts', help='Path to the ground truth txt directory.') parser.add_argument('--out_dir', type=str, default='icdar2015/test_predicts', help='Path to the output txt directory.') parser.add_argument('--img_dir', type=str, default='icdar2015/test_images', help='Path to the test images directory.') @@ -53,15 +55,16 @@ def str2bool(v): parser.add_argument('--iou_constraint', type=float, default=0.3, help='IOU constraint.') parser.add_argument('--area_precision_constraint', type=float, default=0.8, help='Area precision constraint.') parser.add_argument('--max_candidates', type=int, default=200, help='Max candidates of polygons.') -parser.add_argument('--unclip_ratio', type=np.float64, default=2.0, help=' The unclip ratio of the detected text region, which determines the output size.') -parser.add_argument('--save', '-s', type=str, default=False, help='Set true to save results. This flag is invalid when using camera.') -parser.add_argument('--vis', '-v', type=str2bool, default=True, help='Set true to open a window for result visualization. This flag is invalid when using camera.') +parser.add_argument('--unclip_ratio', type=np.float64, default=2.0, + help=' The unclip ratio of the detected text region, which determines the output size.') +parser.add_argument('--save', '-s', type=str, default=False, + help='Set true to save results. This flag is invalid when using camera.') +parser.add_argument('--vis', '-v', type=str2bool, default=True, + help='Set true to open a window for result visualization. This flag is invalid when using camera.') args = parser.parse_args() - - if __name__ == '__main__': - default_evaluation_params={ + default_evaluation_params = { 'IOU_CONSTRAINT': args.iou_constraint, 'AREA_PRECISION_CONSTRAINT': args.area_precision_constraint, 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', @@ -69,7 +72,7 @@ def str2bool(v): 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) 'CRLF': False, # Lines are delimited by Windows CRLF format 'CONFIDENCES': True, # Detections must include confidence value. 
AP will be calculated - 'PER_SAMPLE_RESULTS':False # Generate per sample results and produce data for visualization + 'PER_SAMPLE_RESULTS': False # Generate per sample results and produce data for visualization } # Instantiate DB model = DB(modelPath=args.model, @@ -80,9 +83,9 @@ def str2bool(v): unclipRatio=args.unclip_ratio, backendId=args.backend, targetId=args.target - ) - files = glob.glob(args.img_dir+'/*', recursive=True) - tot_time=0 + ) + files = glob.glob(args.img_dir + '/*', recursive=True) + tot_time = 0 for file in files: image = cv.imread(file) @@ -92,19 +95,21 @@ def str2bool(v): # Inference results = model.infer(image) end = time.time() - tot_time+=(end-start)/len(files) - img_name=file.split('/')[-1].split('.')[0] - text_file = args.out_dir+'/res_' + img_name + '.txt' - result='' + tot_time += (end - start) / len(files) + img_name = file.split('/')[-1].split('.')[0] + text_file = args.out_dir + '/res_' + img_name + '.txt' + result = '' for idx, (bbox, score) in enumerate(zip(results[0], results[1])): - result+='{},{},{},{},{},{},{},{},{}\n'.format(int(bbox[0][0]*scale[0]),int(bbox[0][1]*scale[1]), int(bbox[1][0]*scale[0]), int(bbox[1][1]*scale[1]), int(bbox[2][0]*scale[0]),int(bbox[2][1]*scale[1]), int(bbox[3][0]*scale[0]), int(bbox[3][1]*scale[1]),score) + result += '{},{},{},{},{},{},{},{},{}\n'.format(int(bbox[0][0] * scale[0]), int(bbox[0][1] * scale[1]), + int(bbox[1][0] * scale[0]), int(bbox[1][1] * scale[1]), + int(bbox[2][0] * scale[0]), int(bbox[2][1] * scale[1]), + int(bbox[3][0] * scale[0]), int(bbox[3][1] * scale[1]), + score) with open(text_file, 'w+') as fid: fid.write(result) - avg_time=tot_time/len(files) - result_dict = cal_recall_precison_f1(args.gt_dir,args.out_dir,default_evaluation_params) + avg_time = tot_time / len(files) + result_dict = cal_recall_precison_f1(args.gt_dir, args.out_dir, default_evaluation_params) # Print results - result_dict['avg time']=avg_time + result_dict['avg time'] = avg_time print(result_dict) - - From ee5f0dfba662654ed6e51974ebd85c5ea8652da1 Mon Sep 17 00:00:00 2001 From: stone <11911611@mail.sustech.edu.cn> Date: Thu, 8 Sep 2022 21:20:46 -0700 Subject: [PATCH 5/5] Merge into benchmark draft --- benchmark/benchmark.py | 23 ++-- .../cal_rescall/rrc_evaluation_funcs.py | 2 +- .../cal_rescall/script.py | 0 .../config/text_detection_db_metrics.yaml | 21 ++++ benchmark/detection_metric.py | 87 +++++++++++++ models/text_detection_db/validation.py | 115 ------------------ 6 files changed, 123 insertions(+), 125 deletions(-) rename {models/text_detection_db => benchmark}/cal_rescall/rrc_evaluation_funcs.py (99%) rename {models/text_detection_db => benchmark}/cal_rescall/script.py (100%) create mode 100644 benchmark/config/text_detection_db_metrics.yaml create mode 100644 benchmark/detection_metric.py delete mode 100644 models/text_detection_db/validation.py diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 5ec2849d..8e658144 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -1,10 +1,11 @@ import os import argparse - +# import sys +# sys.path.append('/') import yaml import numpy as np import cv2 as cv - +from detection_metric import detect_metric from models import MODELS from utils import METRICS, DATALOADERS @@ -118,13 +119,17 @@ def printResults(self): # prepend PYTHONPATH to each path prepend_pythonpath(cfg) - # Instantiate benchmarking - benchmark = Benchmark(**cfg['Benchmark']) - # Instantiate model model = build_from_cfg(cfg=cfg['Model'], registery=MODELS, key='name') - - # Run benchmarking 
print('Benchmarking {}:'.format(model.name)) - benchmark.run(model) - benchmark.printResults() + + type = cfg['Benchmark']['type'] + if type=="Detection metrics": + detect_metric(model,**cfg['Benchmark']) + else: + # Instantiate benchmarking + benchmark = Benchmark(**cfg['Benchmark']) + + # Run benchmarking + benchmark.run(model) + benchmark.printResults() diff --git a/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py b/benchmark/cal_rescall/rrc_evaluation_funcs.py similarity index 99% rename from models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py rename to benchmark/cal_rescall/rrc_evaluation_funcs.py index b31576a3..656a2ff9 100644 --- a/models/text_detection_db/cal_rescall/rrc_evaluation_funcs.py +++ b/benchmark/cal_rescall/rrc_evaluation_funcs.py @@ -4,7 +4,7 @@ import json import sys -sys.path.append('./') +sys.path.append('/') import zipfile import re import sys diff --git a/models/text_detection_db/cal_rescall/script.py b/benchmark/cal_rescall/script.py similarity index 100% rename from models/text_detection_db/cal_rescall/script.py rename to benchmark/cal_rescall/script.py diff --git a/benchmark/config/text_detection_db_metrics.yaml b/benchmark/config/text_detection_db_metrics.yaml new file mode 100644 index 00000000..9186a3e1 --- /dev/null +++ b/benchmark/config/text_detection_db_metrics.yaml @@ -0,0 +1,21 @@ +Benchmark: + name: "Text Detection metrics Benchmark" + type: "Detection metrics" + data: + gt: "./data/icdar2015/gt" + imgs: "./data/icdar2015/img" + out: "./data/icdar2015/out" + iou_constraint: 0.3 + area_precision_constraint: 0.8 + height: 736 + width: 736 + backend: "default" + target: "cpu" + +Model: + name: "DB" + modelPath: "models/text_detection_db/text_detection_DB_TD500_resnet18_2021sep.onnx" + binaryThreshold: 0.3 + polygonThreshold: 0.5 + maxCandidates: 200 + unclipRatio: 2.0 \ No newline at end of file diff --git a/benchmark/detection_metric.py b/benchmark/detection_metric.py new file mode 100644 index 00000000..5ed9c47f --- /dev/null +++ b/benchmark/detection_metric.py @@ -0,0 +1,87 @@ +# This file is part of OpenCV Zoo project. +# It is subject to the license terms in the LICENSE file found in the same directory. +# + +import argparse +import glob +import time + +import numpy as np +import cv2 as cv +from cal_rescall.script import cal_recall_precison_f1 + + + + +def detect_metric(model,**kwargs): + backend_id = kwargs.pop('backend', 'default') + data=kwargs['data'] + img_dir=data.pop('imgs') + gt_dir=data.pop('gt') + out_dir=data.pop('out') + width=data.pop('width') + height=data.pop('height') + iou_constraint=data.pop("iou_constraint") + area_precision_constraint=data.pop("area_precision_constraint") + available_backends = dict( + default=cv.dnn.DNN_BACKEND_DEFAULT, + # halide=cv.dnn.DNN_BACKEND_HALIDE, + # inference_engine=cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, + opencv=cv.dnn.DNN_BACKEND_OPENCV, + # vkcom=cv.dnn.DNN_BACKEND_VKCOM, + cuda=cv.dnn.DNN_BACKEND_CUDA, + ) + + target_id = kwargs.pop('target', 'cpu') + available_targets = dict( + cpu=cv.dnn.DNN_TARGET_CPU, + cuda=cv.dnn.DNN_TARGET_CUDA, + cuda_fp16=cv.dnn.DNN_TARGET_CUDA_FP16, + ) + + # add extra backends & targets + try: + available_backends['timvx'] = cv.dnn.DNN_BACKEND_TIMVX + available_targets['npu'] = cv.dnn.DNN_TARGET_NPU + except: + print( + 'OpenCV is not compiled with TIM-VX backend enbaled. 
See https://github.com/opencv/opencv/wiki/TIM-VX-Backend-For-Running-OpenCV-On-NPU for more details on how to enable TIM-VX backend.') + + _backend = available_backends[backend_id] + _target = available_targets[target_id] + model.setBackend(_backend) + model.setTarget(_target) + default_evaluation_params = { + 'IOU_CONSTRAINT': iou_constraint, + 'AREA_PRECISION_CONSTRAINT': area_precision_constraint, + 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', + 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF': False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES': True, # Detections must include confidence value. AP will be calculated + 'PER_SAMPLE_RESULTS': False # Generate per sample results and produce data for visualization + } + # Instantiate DB + + files = glob.glob(img_dir + '/*', recursive=True) + + for file in files: + image = cv.imread(file) + scale = (image.shape[1] * 1.0 / width, image.shape[0] * 1.0 / height) + image = cv.resize(image, [width, height]) + # Inference + results = model.infer(image) + img_name = file.split('/')[-1].split('.')[0] + text_file = out_dir + '/res_' + img_name + '.txt' + result = '' + for idx, (bbox, score) in enumerate(zip(results[0], results[1])): + result += '{},{},{},{},{},{},{},{},{}\n'.format(int(bbox[0][0] * scale[0]), int(bbox[0][1] * scale[1]), + int(bbox[1][0] * scale[0]), int(bbox[1][1] * scale[1]), + int(bbox[2][0] * scale[0]), int(bbox[2][1] * scale[1]), + int(bbox[3][0] * scale[0]), int(bbox[3][1] * scale[1]), + score) + with open(text_file, 'w+') as fid: + fid.write(result) + result_dict = cal_recall_precison_f1(gt_dir, out_dir, default_evaluation_params) + print(result_dict) + return result_dict diff --git a/models/text_detection_db/validation.py b/models/text_detection_db/validation.py deleted file mode 100644 index a5aabcb3..00000000 --- a/models/text_detection_db/validation.py +++ /dev/null @@ -1,115 +0,0 @@ -# This file is part of OpenCV Zoo project. -# It is subject to the license terms in the LICENSE file found in the same directory. -# - -import argparse -import glob -import sys -import time - -import numpy as np -import cv2 as cv - -from db import DB -from cal_rescall.script import cal_recall_precison_f1 - - -def str2bool(v): - if v.lower() in ['on', 'yes', 'true', 'y', 't']: - return True - elif v.lower() in ['off', 'no', 'false', 'n', 'f']: - return False - else: - raise NotImplementedError - - -backends = [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_CUDA] -targets = [cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16] -help_msg_backends = "Choose one of the computation backends: {:d}: OpenCV implementation (default); {:d}: CUDA" -help_msg_targets = "Chose one of the target computation devices: {:d}: CPU (default); {:d}: CUDA; {:d}: CUDA fp16" -try: - backends += [cv.dnn.DNN_BACKEND_TIMVX] - targets += [cv.dnn.DNN_TARGET_NPU] - help_msg_backends += "; {:d}: TIMVX" - help_msg_targets += "; {:d}: NPU" -except: - print( - 'This version of OpenCV does not support TIM-VX and NPU. Visit https://gist.github.com/fengyuentau/5a7a5ba36328f2b763aea026c43fa45f for more information.') - -parser = argparse.ArgumentParser( - description='Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947).') -parser.add_argument('--input', '-i', type=str, help='Path to the input image. 
Omit for using default camera.') -parser.add_argument('--model', '-m', type=str, default='text_detection_DB_TD500_resnet18_2021sep.onnx', - help='Path to the model.') -parser.add_argument('--gt_dir', type=str, default='icdar2015/test_gts', help='Path to the ground truth txt directory.') -parser.add_argument('--out_dir', type=str, default='icdar2015/test_predicts', help='Path to the output txt directory.') -parser.add_argument('--img_dir', type=str, default='icdar2015/test_images', help='Path to the test images directory.') -parser.add_argument('--width', type=int, default=736, - help='Preprocess input image by resizing to a specific width. It should be multiple by 32.') -parser.add_argument('--height', type=int, default=736, - help='Preprocess input image by resizing to a specific height. It should be multiple by 32.') -parser.add_argument('--backend', '-b', type=int, default=backends[0], help=help_msg_backends.format(*backends)) -parser.add_argument('--target', '-t', type=int, default=targets[0], help=help_msg_targets.format(*targets)) -parser.add_argument('--binary_threshold', type=float, default=0.3, help='Threshold of the binary map.') -parser.add_argument('--polygon_threshold', type=float, default=0.5, help='Threshold of polygons.') -parser.add_argument('--iou_constraint', type=float, default=0.3, help='IOU constraint.') -parser.add_argument('--area_precision_constraint', type=float, default=0.8, help='Area precision constraint.') -parser.add_argument('--max_candidates', type=int, default=200, help='Max candidates of polygons.') -parser.add_argument('--unclip_ratio', type=np.float64, default=2.0, - help=' The unclip ratio of the detected text region, which determines the output size.') -parser.add_argument('--save', '-s', type=str, default=False, - help='Set true to save results. This flag is invalid when using camera.') -parser.add_argument('--vis', '-v', type=str2bool, default=True, - help='Set true to open a window for result visualization. This flag is invalid when using camera.') -args = parser.parse_args() - -if __name__ == '__main__': - default_evaluation_params = { - 'IOU_CONSTRAINT': args.iou_constraint, - 'AREA_PRECISION_CONSTRAINT': args.area_precision_constraint, - 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', - 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', - 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) - 'CRLF': False, # Lines are delimited by Windows CRLF format - 'CONFIDENCES': True, # Detections must include confidence value. 
AP will be calculated - 'PER_SAMPLE_RESULTS': False # Generate per sample results and produce data for visualization - } - # Instantiate DB - model = DB(modelPath=args.model, - inputSize=[args.width, args.height], - binaryThreshold=args.binary_threshold, - polygonThreshold=args.polygon_threshold, - maxCandidates=args.max_candidates, - unclipRatio=args.unclip_ratio, - backendId=args.backend, - targetId=args.target - ) - files = glob.glob(args.img_dir + '/*', recursive=True) - tot_time = 0 - - for file in files: - image = cv.imread(file) - scale = (image.shape[1] * 1.0 / args.width, image.shape[0] * 1.0 / args.height) - image = cv.resize(image, [args.width, args.height]) - start = time.time() - # Inference - results = model.infer(image) - end = time.time() - tot_time += (end - start) / len(files) - img_name = file.split('/')[-1].split('.')[0] - text_file = args.out_dir + '/res_' + img_name + '.txt' - result = '' - for idx, (bbox, score) in enumerate(zip(results[0], results[1])): - result += '{},{},{},{},{},{},{},{},{}\n'.format(int(bbox[0][0] * scale[0]), int(bbox[0][1] * scale[1]), - int(bbox[1][0] * scale[0]), int(bbox[1][1] * scale[1]), - int(bbox[2][0] * scale[0]), int(bbox[2][1] * scale[1]), - int(bbox[3][0] * scale[0]), int(bbox[3][1] * scale[1]), - score) - with open(text_file, 'w+') as fid: - fid.write(result) - - avg_time = tot_time / len(files) - result_dict = cal_recall_precison_f1(args.gt_dir, args.out_dir, default_evaluation_params) - # Print results - result_dict['avg time'] = avg_time - print(result_dict)
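
Reviewer note (appended after the patch series; not part of any hunk above): a
minimal usage sketch for the new "Detection metrics" path. It assumes the file
layout produced by these patches (ground truth named gt_img_<N>.txt, detections
written as res_img_<N>.txt with one "x1,y1,x2,y2,x3,y3,x4,y4,confidence" line
per box) and the config file added under benchmark/config/. The flag for
passing the config follows the existing benchmark CLI and may need adjusting:

    python benchmark/benchmark.py --cfg config/text_detection_db_metrics.yaml

The recall/precision/hmean that cal_recall_precison_f1 reports follow the
standard ICDAR-style definitions; the final aggregation step is equivalent to
the small reference function below (illustrative only, not code from the patch):

    def aggregate_scores(matched, num_gt_care, num_det_care):
        """Recall, precision and hmean from global match counts."""
        recall = 1.0 if num_gt_care == 0 else matched / num_gt_care
        precision = 0.0 if num_det_care == 0 else matched / num_det_care
        hmean = 0.0 if recall + precision == 0 else \
            2.0 * recall * precision / (recall + precision)
        return {'recall': recall, 'precision': precision, 'hmean': hmean}

    if __name__ == '__main__':
        # Example with made-up counts: 802 matched pairs, 1000 cared-for GT
        # boxes, 950 cared-for detections.
        print(aggregate_scores(802, 1000, 950))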