From 15b2a4e16c12fd74286efe904f4ca7a122e7d5d6 Mon Sep 17 00:00:00 2001 From: coincheung <867153576@qq.com> Date: Mon, 27 Jun 2022 09:07:01 +0000 Subject: [PATCH] modify script of check dataset --- README.md | 1 + tools/check_dataset_info.py | 49 ++++++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 4b32b4f..3068e92 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,7 @@ I recommand you to check the information of your dataset with the script: ``` $ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file ``` +This will print some of the information of your dataset. Then you need to change the field of `im_root` and `train/val_im_anns` in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this conig file. diff --git a/tools/check_dataset_info.py b/tools/check_dataset_info.py index 030a75b..183ba44 100644 --- a/tools/check_dataset_info.py +++ b/tools/check_dataset_info.py @@ -11,8 +11,11 @@ parse = argparse.ArgumentParser() parse.add_argument('--im_root', dest='im_root', type=str, default='./datasets/cityscapes',) parse.add_argument('--im_anns', dest='im_anns', type=str, default='./datasets/cityscapes/train.txt',) +parse.add_argument('--lb_ignore', dest='lb_ignore', type=int, default=255) args = parse.parse_args() +lb_ignore = args.lb_ignore + with open(args.im_anns, 'r') as fr: lines = fr.read().splitlines() @@ -54,15 +57,20 @@ if shape[1] < min_shape_width[1]: min_shape_width = shape - max_lb_val = max(max_lb_val, np.max(lb.ravel())) - min_lb_val = min(min_lb_val, np.min(lb.ravel())) - + lb = lb[lb != lb_ignore] + if lb.size > 0: + max_lb_val = max(max_lb_val, np.max(lb)) + min_lb_val = min(min_lb_val, np.min(lb)) +min_lb_val = 0 +max_lb_val = 181 +lb_minlength = 182 ## label info lb_minlength = max_lb_val+1-min_lb_val lb_hist = np.zeros(lb_minlength) -for impth in tqdm(impaths): - lb = cv2.imread(lbpth, 0).ravel() + min_lb_val +for lbpth in tqdm(lbpaths): + lb = cv2.imread(lbpth, 0) + lb = lb[lb != lb_ignore] + min_lb_val lb_hist += np.bincount(lb, minlength=lb_minlength) lb_missing_vals = [ind + min_lb_val @@ -75,38 +83,39 @@ n_pixels = 0 for impth in tqdm(impaths): im = cv2.imread(impth)[:, :, ::-1].astype(np.float32) - im = im.reshape(-1, 3) + im = im.reshape(-1, 3) / 255. n_pixels += im.shape[0] rgb_mean += im.sum(axis=0) -rgb_mean = rgb_mean / n_pixels +rgb_mean = (rgb_mean / n_pixels) rgb_std = np.zeros(3).astype(np.float32) for impth in tqdm(impaths): im = cv2.imread(impth)[:, :, ::-1].astype(np.float32) - im = im.reshape(-1, 3) + im = im.reshape(-1, 3) / 255. a = (im - rgb_mean.reshape(1, 3)) ** 2 rgb_std += a.sum(axis=0) -rgb_std = (rgb_std / n_pixels) ** (0.5) +rgb_std = (rgb_std / n_pixels) ** 0.5 + +rgb_mean = rgb_mean.tolist() +rgb_std = rgb_std.tolist() +print('\n') print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs') print('\n') -print('max and min image shapes by area are: ') -print(f'\t{max_shape_area}, {min_shape_area}') -print('max and min image shapes by height are: ') -print(f'\t{max_shape_height}, {min_shape_height}') -print('max and min image shapes by width are: ') -print(f'\t{max_shape_width}, {min_shape_width}') +print(f'max and min image shapes by area are: {max_shape_area}, {min_shape_area}') +print(f'max and min image shapes by height are: {max_shape_height}, {min_shape_height}') +print(f'max and min image shapes by width are: {max_shape_width}, {min_shape_width}') print('\n') -print(f'label values are within range of ({min_lb_val}, {max_lb_val})') -print('label values that are missing: ') -print('\t', lb_missing_vals) +print(f'we ignore label value of {args.lb_ignore} in label images') +print(f'label values are within range of [{min_lb_val}, {max_lb_val}]') +print(f'label values that are missing: {lb_missing_vals}') print('ratios of each label value: ') print('\t', lb_ratios) print('\n') -print('pixel mean rgb: ', mean) -print('pixel std rgb: ', std) +print('pixel mean rgb: ', rgb_mean) +print('pixel std rgb: ', rgb_std)