Skip to content

Commit

Permalink
Merge 96bd06d into a9fd839
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochaorui authored Feb 18, 2021
2 parents a9fd839 + 96bd06d commit 2250432
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmedit
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pymatting,pytest,scipy,titlecase,torch,torchvision,tqdm
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pymatting,pytest,scipy,titlecase,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
17 changes: 9 additions & 8 deletions tools/data/matting/comp1k/extend_fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import os.path as osp
import re
import subprocess
from multiprocessing import Pool

import mmcv
import numpy as np
from PIL import Image
from pymatting import estimate_foreground_ml, load_image
from tqdm import tqdm


def fix_png_file(filename, folder):
Expand Down Expand Up @@ -70,6 +69,10 @@ def extend(self, fg_name):
fg = Image.fromarray(np.uint8(F * 255))
fg.save(extended_path)
fix_png_file(osp.basename(extended_path), osp.dirname(extended_path))
data_info = dict()
data_info['alpha_path'] = alpha_path
data_info['fg_extended_path'] = extended_path
return data_info


def parse_args():
Expand Down Expand Up @@ -112,16 +115,14 @@ def main():
os.makedirs(p, exist_ok=True)

fg_names = osp.join(dir_prefix, f'{fname_prefix}_fg_names.txt')
save_json_path = f'{fname_prefix}_list.json'
fg_names = open(osp.join(data_root, fg_names)).readlines()
fg_iter = iter(fg_names)
num = len(fg_names)

extend_fg = ExtendFg(data_root, fg_dirs, alpha_dirs)
with Pool(processes=args.nproc) as p:
with tqdm(total=num) as pbar:
for i, _ in tqdm(
enumerate(p.imap_unordered(extend_fg.extend, fg_iter))):
pbar.update()
data_infos = mmcv.track_parallel_progress(extend_fg.extend, list(fg_iter),
args.nproc)
mmcv.dump(data_infos, osp.join(data_root, save_json_path))

print('train done')

Expand Down

0 comments on commit 2250432

Please sign in to comment.