Skip to content

Commit

Permalink
Merge pull request #5222 from WenmuZhou/update_whl_dygraph
Browse files Browse the repository at this point in the history
rm model type check
  • Loading branch information
WenmuZhou authored Jan 10, 2022
2 parents 5a537bb + c0ce890 commit c910bf8
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
BASE_DIR = os.path.expanduser("~/.paddleocr/")

DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2']
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE']
MODEL_URLS = {
'OCR': {
'PP-OCRv2': {
Expand Down Expand Up @@ -190,6 +192,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--ocr_version",
type=str,
choices=SUPPORT_OCR_MODEL_VERSION,
default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. '
Expand All @@ -198,6 +201,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--structure_version",
type=str,
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='STRUCTURE',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
Expand Down Expand Up @@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError

model_urls = MODEL_URLS[type]
if version not in model_urls:
logger.warning('version {} not in {}, auto switch to version {}'.format(
version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
if model_type not in model_urls[version]:
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
logger.warning(
'version {} not support {} models, auto switch to version {}'.
format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error('{} models is not support, we only support {}'.format(
model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1)

if lang not in model_urls[version][model_type]:
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
logger.warning(
'lang {} is not support in {}, auto switch to version {}'.
format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error(
Expand All @@ -296,6 +294,8 @@ def __init__(self, **kwargs):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(
SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
params.use_gpu = check_gpu(params.use_gpu)

if not params.show_log:
Expand Down Expand Up @@ -398,6 +398,8 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu)

if not params.show_log:
Expand Down

0 comments on commit c910bf8

Please sign in to comment.