Skip to content

Commit

Permalink
add script of check dataset information
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Jun 26, 2022
1 parent 2717c9b commit 8587c1f
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ frankfurt_000001_079206_leftImg8bit.png,frankfurt_000001_079206_gtFine_labelIds.
...
```
Each line is a pair of training sample and ground truth image path, which are separated by a single comma `,`.
I recommand you to check the information of your dataset with the script:
```
$ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file
```
Then you need to change the field of `im_root` and `train/val_im_anns` in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this conig file.


Expand Down
112 changes: 112 additions & 0 deletions tools/check_dataset_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@

import os
import os.path as osp
import argparse
from tqdm import tqdm

import cv2
import numpy as np


parse = argparse.ArgumentParser()
parse.add_argument('--im_root', dest='im_root', type=str, default='./datasets/cityscapes',)
parse.add_argument('--im_anns', dest='im_anns', type=str, default='./datasets/cityscapes/train.txt',)
args = parse.parse_args()


with open(args.im_anns, 'r') as fr:
lines = fr.read().splitlines()

n_pairs = len(lines)
impaths, lbpaths = [], []
for l in lines:
impth, lbpth = l.split(',')
impth = osp.join(args.im_root, impth)
lbpth = osp.join(args.im_root, lbpth)
impaths.append(impth)
lbpaths.append(lbpth)


## shapes
max_shape_area, min_shape_area = [0, 0], [100000, 100000]
max_shape_height, min_shape_height = [0, 0], [100000, 100000]
max_shape_width, min_shape_width = [0, 0], [100000, 100000]
max_lb_val, min_lb_val = -1, 10000000
for impth, lbpth in tqdm(zip(impaths, lbpaths), total=n_pairs):
im = cv2.imread(impth)[:, :, ::-1]
lb = cv2.imread(lbpth, 0)
assert im.shape[:2] == lb.shape

shape = lb.shape
area = shape[0] * shape[1]
if area > max_shape_area[0] * max_shape_area[1]:
max_shape_area = shape
if area < min_shape_area[0] * min_shape_area[1]:
min_shape_area = shape

if shape[0] > max_shape_height[0]:
max_shape_height = shape
if shape[0] < min_shape_height[0]:
min_shape_height = shape

if shape[1] > max_shape_width[1]:
max_shape_width = shape
if shape[1] < min_shape_width[1]:
min_shape_width = shape

max_lb_val = max(max_lb_val, np.max(lb.ravel()))
min_lb_val = min(min_lb_val, np.min(lb.ravel()))


## label info
lb_minlength = max_lb_val+1-min_lb_val
lb_hist = np.zeros(lb_minlength)
for impth in tqdm(impaths):
lb = cv2.imread(lbpth, 0).ravel() + min_lb_val
lb_hist += np.bincount(lb, minlength=lb_minlength)

lb_missing_vals = [ind + min_lb_val
for ind, el in enumerate(lb_hist.tolist()) if el == 0]
lb_ratios = (lb_hist / lb_hist.sum()).tolist()


## pixel mean/std
rgb_mean = np.zeros(3).astype(np.float32)
n_pixels = 0
for impth in tqdm(impaths):
im = cv2.imread(impth)[:, :, ::-1].astype(np.float32)
im = im.reshape(-1, 3)
n_pixels += im.shape[0]
rgb_mean += im.sum(axis=0)
rgb_mean = rgb_mean / n_pixels

rgb_std = np.zeros(3).astype(np.float32)
for impth in tqdm(impaths):
im = cv2.imread(impth)[:, :, ::-1].astype(np.float32)
im = im.reshape(-1, 3)

a = (im - rgb_mean.reshape(1, 3)) ** 2
rgb_std += a.sum(axis=0)
rgb_std = (rgb_std / n_pixels) ** (0.5)


print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs')
print('\n')

print('max and min image shapes by area are: ')
print(f'\t{max_shape_area}, {min_shape_area}')
print('max and min image shapes by height are: ')
print(f'\t{max_shape_height}, {min_shape_height}')
print('max and min image shapes by width are: ')
print(f'\t{max_shape_width}, {min_shape_width}')
print('\n')

print(f'label values are within range of ({min_lb_val}, {max_lb_val})')
print('label values that are missing: ')
print('\t', lb_missing_vals)
print('ratios of each label value: ')
print('\t', lb_ratios)
print('\n')

print('pixel mean rgb: ', mean)
print('pixel std rgb: ', std)

0 comments on commit 8587c1f

Please sign in to comment.