Skip to content

Commit

Permalink
modify script of check dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Jun 27, 2022
1 parent 8587c1f commit 15b2a4e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ I recommand you to check the information of your dataset with the script:
```
$ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file
```
This will print some of the information of your dataset.
Then you need to change the field of `im_root` and `train/val_im_anns` in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this conig file.


Expand Down
49 changes: 29 additions & 20 deletions tools/check_dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
parse = argparse.ArgumentParser()
parse.add_argument('--im_root', dest='im_root', type=str, default='./datasets/cityscapes',)
parse.add_argument('--im_anns', dest='im_anns', type=str, default='./datasets/cityscapes/train.txt',)
parse.add_argument('--lb_ignore', dest='lb_ignore', type=int, default=255)
args = parse.parse_args()

lb_ignore = args.lb_ignore


with open(args.im_anns, 'r') as fr:
lines = fr.read().splitlines()
Expand Down Expand Up @@ -54,15 +57,20 @@
if shape[1] < min_shape_width[1]:
min_shape_width = shape

max_lb_val = max(max_lb_val, np.max(lb.ravel()))
min_lb_val = min(min_lb_val, np.min(lb.ravel()))

lb = lb[lb != lb_ignore]
if lb.size > 0:
max_lb_val = max(max_lb_val, np.max(lb))
min_lb_val = min(min_lb_val, np.min(lb))

min_lb_val = 0
max_lb_val = 181
lb_minlength = 182
## label info
lb_minlength = max_lb_val+1-min_lb_val
lb_hist = np.zeros(lb_minlength)
for impth in tqdm(impaths):
lb = cv2.imread(lbpth, 0).ravel() + min_lb_val
for lbpth in tqdm(lbpaths):
lb = cv2.imread(lbpth, 0)
lb = lb[lb != lb_ignore] + min_lb_val
lb_hist += np.bincount(lb, minlength=lb_minlength)

lb_missing_vals = [ind + min_lb_val
Expand All @@ -75,38 +83,39 @@
n_pixels = 0
for impth in tqdm(impaths):
im = cv2.imread(impth)[:, :, ::-1].astype(np.float32)
im = im.reshape(-1, 3)
im = im.reshape(-1, 3) / 255.
n_pixels += im.shape[0]
rgb_mean += im.sum(axis=0)
rgb_mean = rgb_mean / n_pixels
rgb_mean = (rgb_mean / n_pixels)

rgb_std = np.zeros(3).astype(np.float32)
for impth in tqdm(impaths):
im = cv2.imread(impth)[:, :, ::-1].astype(np.float32)
im = im.reshape(-1, 3)
im = im.reshape(-1, 3) / 255.

a = (im - rgb_mean.reshape(1, 3)) ** 2
rgb_std += a.sum(axis=0)
rgb_std = (rgb_std / n_pixels) ** (0.5)
rgb_std = (rgb_std / n_pixels) ** 0.5

rgb_mean = rgb_mean.tolist()
rgb_std = rgb_std.tolist()


print('\n')
print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs')
print('\n')

print('max and min image shapes by area are: ')
print(f'\t{max_shape_area}, {min_shape_area}')
print('max and min image shapes by height are: ')
print(f'\t{max_shape_height}, {min_shape_height}')
print('max and min image shapes by width are: ')
print(f'\t{max_shape_width}, {min_shape_width}')
print(f'max and min image shapes by area are: {max_shape_area}, {min_shape_area}')
print(f'max and min image shapes by height are: {max_shape_height}, {min_shape_height}')
print(f'max and min image shapes by width are: {max_shape_width}, {min_shape_width}')
print('\n')

print(f'label values are within range of ({min_lb_val}, {max_lb_val})')
print('label values that are missing: ')
print('\t', lb_missing_vals)
print(f'we ignore label value of {args.lb_ignore} in label images')
print(f'label values are within range of [{min_lb_val}, {max_lb_val}]')
print(f'label values that are missing: {lb_missing_vals}')
print('ratios of each label value: ')
print('\t', lb_ratios)
print('\n')

print('pixel mean rgb: ', mean)
print('pixel std rgb: ', std)
print('pixel mean rgb: ', rgb_mean)
print('pixel std rgb: ', rgb_std)

0 comments on commit 15b2a4e

Please sign in to comment.