
[eval_hooks.py] TypeError: '>' not supported between instances of 'NoneType' and 'float' #7096

Closed
huang-jesse opened this issue Jan 29, 2022 · 3 comments
Labels: bug

Comments

huang-jesse (Contributor) commented Jan 29, 2022

Checklist

  1. I have searched related issues but cannot get the expected help. I think I hit the same issue as "ERROR - The testing results of the whole dataset is empty - YOLOX and COCO", which was fixed for mmcv but is not resolved in eval_hooks.py.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug

I think I hit the same issue as "ERROR - The testing results of the whole dataset is empty - YOLOX and COCO", which was fixed for mmcv but is not resolved in eval_hooks.py.

Reproduction

  1. I just trained YOLOX-L on a COCO-format dataset; here is the config:
_base_ = './yolox_s_8x8_300e_coco.py'
img_scale = (1280, 1280)
# model settings
model = dict(
    input_size=img_scale,
    backbone=dict(deepen_factor=1.0, widen_factor=1.0),
    neck=dict(
        in_channels=[256, 512, 1024], out_channels=256, num_csp_blocks=3),
    bbox_head=dict(
        type='YOLOXHead', num_classes=1, in_channels=256, feat_channels=256),)

# Modify dataset related settings
data_root = 'data/tgbr_tile/'
dataset_type = 'CocoDataset'
classes = ('my_class',)

# Modify pipeline
albu_train_transforms = [
    # dict(type='RandomRotate90', always_apply=False, p=0.5),
    dict(
        type='OneOf',
        transforms=[
            dict(type='Blur', blur_limit=3, p=1.0),
            dict(type='MedianBlur', blur_limit=3, p=1.0)
        ],
        p=0.3),
    dict(
        type='RandomBrightnessContrast',
        brightness_limit=[-0.3, 0.3],
        contrast_limit=[-0.3, 0.3],
        p=0.3),
]
train_pipeline = [
    dict(type='Rotate', level=10, prob=0.5, max_rotate_angle=90),
    dict(type='Albu',
         transforms=albu_train_transforms,
         bbox_params=dict(type='BboxParams',
                          format='pascal_voc',
                          label_fields=['gt_labels'],
                          min_visibility=0.0,
                          filter_lost_elements=True),
         keymap={'img': 'image', 'gt_bboxes': 'bboxes'},
         update_pad_shape=False,
         skip_img_without_anno=False),
    dict(
        type='RandomAffine',
        max_rotate_degree=0,
        scaling_ratio_range=(0.5, 1.5),
        border=(0, 0)),
    dict(
        type='CutOut',
        n_holes=(0, 3),
        cutout_shape=[(20, 20), (30, 30)]),
    dict(
        type='MinIoURandomCrop',
        min_ious=(0.7, 0.9,),
        min_crop_size=0.2),
    dict(type='CLAHE_or_HE_HSV', prob=0.7),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', flip_ratio=0.5),
    # According to the official implementation, multi-scale
    # training is not considered here but in the
    # 'mmdet/models/detectors/yolox.py'.
    # dict(type='Resize', img_scale=img_scale, keep_ratio=True),
    dict(type='Resize',
         img_scale=img_scale,
         keep_ratio=True
         ),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        img_prefix='train_images/',
        ann_file='labels/coco_train.json',
        classes=classes,
        data_root=data_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_empty_gt=False,
    ),
    pipeline=train_pipeline)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        flip=False,
        transforms=[
            dict(type='Resize',
                 keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Pad',
                pad_to_square=True,
                pad_val=dict(img=(114.0, 114.0, 114.0))),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img'])
        ])
]

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=4,
    train=train_dataset,
    val=dict(
        type=dataset_type,
        img_prefix='train_images/',
        ann_file='labels/coco_val.json',
        classes=classes,
        data_root=data_root,
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        img_prefix='train_images/',
        ann_file='labels/coco_val.json',
        classes=classes,
        data_root=data_root,
        pipeline=test_pipeline))

# optimizer
# The base config assumes 8 GPUs x 8 imgs/GPU; we train on 1 GPU with
# 2 imgs/GPU, so the lr is scaled down linearly: 0.01 / (8 * 2 * 2)
optimizer = dict(
    type='SGD',
    lr=(0.01/(8*2*2)),
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=None)

# max_epochs and the step in lr_config need to be tuned specifically for the customized dataset
max_epochs = 100
num_last_epochs = 5
resume_from = None
interval = 2

# learning policy
lr_config = dict(
    _delete_=True,
    policy='YOLOX',
    warmup='exp',
    by_epoch=False,
    warmup_by_epoch=True,
    warmup_ratio=1,
    warmup_iters=5,  # 5 epoch
    num_last_epochs=num_last_epochs,
    min_lr_ratio=0.05)

runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(
        type='SyncNormHook',
        num_last_epochs=num_last_epochs,
        interval=interval,
        priority=48),
    dict(
        type='ExpMomentumEMAHook',
        resume_from=resume_from,
        momentum=0.0001,
        priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(
    save_best='auto',
    # The evaluation interval is `interval` when the running epoch is
    # less than `max_epochs - num_last_epochs`.
    # The evaluation interval is 1 when the running epoch is greater than
    # or equal to `max_epochs - num_last_epochs`.
    interval=interval,
    dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
    metric='bbox')
log_config = dict(interval=40)
# We can use the pre-trained model to obtain higher performance
load_from = 'checkpoints/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
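
As a side note on the `evaluation` dict above: the schedule its comments describe can be made concrete with a small sketch. The `current_interval` helper below is hypothetical, written only to illustrate the behavior of `interval=2` plus `dynamic_intervals=[(95, 1)]`, not mmdet's actual implementation:

# Hypothetical helper illustrating the evaluation schedule above.
max_epochs = 100
num_last_epochs = 5
interval = 2
dynamic_intervals = [(max_epochs - num_last_epochs, 1)]  # [(95, 1)]

def current_interval(epoch):
    # epoch is 1-based here (an assumption made for illustration)
    for start_epoch, new_interval in dynamic_intervals:
        if epoch >= start_epoch:
            return new_interval
    return interval

assert current_interval(94) == 2  # before epoch 95: evaluate every 2 epochs
assert current_interval(95) == 1  # from epoch 95 on: evaluate every epoch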

Environment

Necessary environment information:

sys.platform: linux
Python: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0]
CUDA available: True
GPU 0: Quadro RTX 8000
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 10.2, V10.2.89
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.10.1
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.2, CUDNN_VERSION=7.6.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

TorchVision: 0.11.2
OpenCV: 4.5.5
MMCV: 1.4.0
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 10.2
MMDetection: 2.19.0+4d41c05

Error traceback
The error log is below:

2022-01-28 22:33:24,528 - mmdet - INFO - Evaluating bbox...
Loading and preparing results...
2022-01-28 22:33:24,529 - mmdet - ERROR - The testing results of the whole dataset is empty.
/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/hooks/evaluation.py:374: UserWarning: Since `eval_res` is an empty dict, the behavior to save the best checkpoint will be skipped in this evaluation.
  warnings.warn(
Traceback (most recent call last):
  File "/home/superdisk/tensorflow-great-barrier-reef/tools/train.py", line 185, in <module>    main()
  File "/home/superdisk/tensorflow-great-barrier-reef/tools/train.py", line 174, in main
    train_detector(  File "/home/superdisk/tensorflow-great-barrier-reef/mmdet/apis/train.py", line 203, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/epoch_based_runner.py", line 54, in train
    self.call_hook('after_train_epoch')
  File "/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/base_runner.py", line 307, in call_hook
    getattr(hook, fn_name)(self)
  File "/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/hooks/evaluation.py", line 267, in after_train_epoch
    self._do_evaluate(runner)
  File "/home/superdisk/tensorflow-great-barrier-reef/mmdet/core/evaluation/eval_hooks.py", line 60, in _do_evaluate
    self._save_ckpt(runner, key_score)
  File "/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/hooks/evaluation.py", line 330, in _save_ckpt
    if self.compare_func(key_score, best_score):
  File "/root/anaconda3/lib/python3.9/site-packages/mmcv/runner/hooks/evaluation.py", line 77, in <lambda>
    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
TypeError: '>' not supported between instances of 'NoneType' and 'float'
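
The last two frames show the root cause: when the evaluation results are empty, `self.evaluate(runner, results)` returns `None`, and mmcv's `rule_map` lambda then compares `None` against the previous best score. The failure can be reproduced in isolation (the values below are illustrative):

# Minimal reproduction of the comparison that fails in
# mmcv/runner/hooks/evaluation.py.
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
compare_func = rule_map['greater']

key_score = None   # what evaluate() yields when the eval results are empty
best_score = 0.0   # a previously recorded best bbox mAP (illustrative)

compare_func(key_score, best_score)
# TypeError: '>' not supported between instances of 'NoneType' and 'float'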

Bug fix

I think we can fix this by modifying eval_hooks.py as below, but I'm not sure that's right.

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        if not self._should_evaluate(runner):
            return

        from mmdet.apis import single_gpu_test
        results = single_gpu_test(runner.model, self.dataloader, show=False)
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        # `key_score` may be None, so skip saving the best checkpoint
        # in that case
        if self.save_best and key_score:
            self._save_ckpt(runner, key_score)
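
If this guard is right, the distributed hook in the same file presumably needs it as well, since DistEvalHook._do_evaluate also ends with a call to self._save_ckpt(runner, key_score). A sketch, assuming its rank-0 branch mirrors the single-GPU structure above:

        # In DistEvalHook._do_evaluate (sketch; assumes the rank-0 branch
        # mirrors the single-GPU hook above):
        if runner.rank == 0:
            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, results)
            # `key_score` may be None here too, so apply the same guard
            if self.save_best and key_score:
                self._save_ckpt(runner, key_score)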
huang-jesse changed the title from "[eval_hooks.py]ERROR - The testing results of the whole dataset is empty." to "[eval_hooks.py] ERROR - The testing results of the whole dataset is empty." Jan 29, 2022
huang-jesse changed the title from "[eval_hooks.py] ERROR - The testing results of the whole dataset is empty." to "[eval_hooks.py] TypeError: '>' not supported between instances of 'NoneType' and 'float'" Jan 29, 2022
jbwang1997 (Collaborator) commented

Thank you for your report! We will discuss this situation as soon as possible.

jbwang1997 added the bug label Jan 29, 2022
jbwang1997 self-assigned this Jan 29, 2022
jbwang1997 (Collaborator) commented

Hello @LuooChen, this is really a bug in mmdetection and your modification is right.
If possible, can you create a pull request following this guideline and contribute to mmdetection? We appreciate your contribution.

huang-jesse (Contributor, Author) commented

> Hello @LuooChen, this is really a bug in mmdetection and your modification is right. If possible, can you create a pull request following this guideline and contribute to mmdetection? We appreciate your contribution.

OK, I will do it as soon as possible.
