Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rectangle training and training your own dataset #51

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ To test LCNN on your own images, you need download the pre-trained models and ex

```Bash
python ./demo.py -d 0 config/wireframe.yaml <path-to-pretrained-pth> <path-to-image>
python ./demo.py -d 0 config/wireframe.yaml /home/zengxh/workspace/lcnn/config/190418-201834-f8934c6-lr4d10-312k.pth.tar /home/zengxh/workspace/lcnn/data/wireframe/train/00559828_3.png
```
Here, `-d 0` is specifying the GPU ID used for evaluation, and you can specify `-d ""` to force CPU inference.

Expand Down Expand Up @@ -200,3 +201,10 @@ If you find L-CNN useful in your research, please consider citing:
year={2019}
}
```

/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/dataset/train_test_split.py
/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/dataset/wireframe.py /home/zengxh/datasets/creepageDistance /home/zengxh/datasets/creepageDistance_wireframe
/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/train.py -d 0 --identifier baseline config/wireframe.yaml
/home/zengxh/anaconda3/envs/CreepageDistance/bin/python3.8 /home/zengxh/workspace/lcnn/demo.py -d 0 config/wireframe.yaml /home/zengxh/workspace/lcnn/logs/210706-112447-88f281a-baseline/checkpoint_best.pth /home/zengxh/datasets/creepageDistance_wireframe/valid/7507639237000304_0_t_0.png


2 changes: 1 addition & 1 deletion config/wireframe.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
io:
logdir: logs/
datadir: data/wireframe/
datadir: /home/zengxh/datasets/creepageDistance_wireframe/
resume_from:
num_workers: 4
tensorboard_port: 0
Expand Down
17 changes: 17 additions & 0 deletions dataset/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# !/usr/bin/env python
# -- coding: utf-8 --
# @Author zengxiaohui
# Datatime:7/6/2021 10:51 AM
# @File:constants.py
# 用于归一化的宽度值
NORMALIZATION_WIDTH = 64
NORMALIZATION_HEIGHT = 512
# 像素最大值为255
PIXS_MAX_VALUE = 255.0
# 数据类型
TB_DATATYPE = "tb"
LR_DATATYPE = "lr"
# 准确率容错距离
ACC_PX_THRESH=16
# 随机种子
RANDOM_SEED = 1024
9 changes: 9 additions & 0 deletions dataset/creepageDistance.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
datasets:
lr:
allDatas: /home/zengxh/medias/data/ext/creepageDistance/datasets/lr/org
# h*w
img_size: [512,64]
tb:
allDatas: /home/zengxh/medias/data/ext/creepageDistance/datasets/tb/org
# h*w
img_size: [512,64]
168 changes: 168 additions & 0 deletions dataset/train_test_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import argparse
import os
import shutil
import random

import cv2
import yaml
from imutils import paths
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT, RANDOM_SEED
from python_developer_tools.cv.datasets.datasets_utils import resize_image, letterbox
from python_developer_tools.files.common import mkdir, get_filename_suf_pix
from python_developer_tools.files.json_utils import read_json_file, save_json_file

def createDatasets(datasets, dirname):
dataDir = os.path.join(data_dict_tmp[dirname])
if not os.path.exists(dataDir):
os.makedirs(dataDir)
for datapath in datasets:
shutil.copy(datapath.replace(".jpg",".json"), dataDir)
shutil.copy(datapath, dataDir)

def get_origin_image_points(imagePath):
img = cv2.imread(imagePath)
jsonfile = imagePath.replace(".jpg", ".json")
json_cont = read_json_file(jsonfile)
labels_tmp = [0, 0, 0, 0, 0, 0, 0, 0]
for shapes in json_cont["shapes"]:
label = shapes["label"]
points = shapes["points"]
if json_cont["imageHeight"] > json_cont["imageWidth"]:
if label == "min":
if points[0][1] < points[1][1]:
labels_tmp[4] = points[0][0]
labels_tmp[5] = points[0][1]
labels_tmp[6] = points[1][0]
labels_tmp[7] = points[1][1]
else:
labels_tmp[4] = points[1][0]
labels_tmp[5] = points[1][1]
labels_tmp[6] = points[0][0]
labels_tmp[7] = points[0][1]
if label == "max":
if points[0][1] < points[1][1]:
labels_tmp[0] = points[0][0]
labels_tmp[1] = points[0][1]
labels_tmp[2] = points[1][0]
labels_tmp[3] = points[1][1]
else:
labels_tmp[0] = points[1][0]
labels_tmp[1] = points[1][1]
labels_tmp[2] = points[0][0]
labels_tmp[3] = points[0][1]
else:
if label == "min":
if points[0][0] < points[1][0]:
labels_tmp[4] = points[0][0]
labels_tmp[5] = points[0][1]
labels_tmp[6] = points[1][0]
labels_tmp[7] = points[1][1]
else:
labels_tmp[4] = points[1][0]
labels_tmp[5] = points[1][1]
labels_tmp[6] = points[0][0]
labels_tmp[7] = points[0][1]
if label == "max":
if points[0][0] < points[1][0]:
labels_tmp[0] = points[0][0]
labels_tmp[1] = points[0][1]
labels_tmp[2] = points[1][0]
labels_tmp[3] = points[1][1]
else:
labels_tmp[0] = points[1][0]
labels_tmp[1] = points[1][1]
labels_tmp[2] = points[0][0]
labels_tmp[3] = points[0][1]
return img,labels_tmp

def label_transpose_1(label_o,w0,h0):
# tb 顺时针旋转90°
new_label = [0, 0, 0, 0, 0, 0, 0, 0]
new_label[0] = w0-label_o[5]
new_label[1] = label_o[4]
new_label[2] = w0-label_o[7]
new_label[3] = label_o[6]
new_label[4] = w0-label_o[1]
new_label[5] = label_o[0]
new_label[6] = w0-label_o[3]
new_label[7] = label_o[2]
return new_label

def labels_convert_train(label,w0,h0,w1,h1,w2,h2,padw,padh):
new_label = [0,0,0,0,0,0,0,0]
for i,_label in enumerate(label):
if i in [0,2,4,6]:
new_label[i] = ((_label * w1 / w0) * w2 ) / w1 + padw
else:
new_label[i] = ((_label * h1 / h0) * h2 ) / h1 + padh
# label = [i / w0 for i in label]
# label = [i * w1 / w0 for i in label]
# label = [(i * w2+ padw) / w1 for i in label]
# label = [i / NORMALIZATION_WIDTH for i in label]
return new_label

def get_dict_json(imagePath):
filename, filedir, filesuffix, filenamestem = get_filename_suf_pix(imagePath)
img, labels_tmp = get_origin_image_points(imagePath)
if key == "tb":
img = cv2.transpose(img)
img = cv2.flip(img, 1)
h0, w0 = img.shape[:2] # orig hw
labels_tmp = label_transpose_1(labels_tmp, w0, h0)
# _ = cv2.line(img, (int(labels_tmp[0]), int(labels_tmp[1])), (int(labels_tmp[2]), int(labels_tmp[3])),
# (0, 255, 0), thickness=2)
# _ = cv2.line(img, (int(labels_tmp[4]), int(labels_tmp[5])), (int(labels_tmp[6]), int(labels_tmp[7])),
# (0, 0, 255), thickness=2)
# cv2.imwrite("sdf.jpg", img)
h0, w0 = img.shape[:2]

img = resize_image(img, [NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH])
h1, w1, _ = img.shape
img, ratio, pad = letterbox(img, [NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH], auto=False, scaleup=True)
#letterbox(img, img_size, auto=False,scaleFill=True) # 会填充边缘 letterbox(img, self.opt.img_size, auto=False, scaleup=False)
labels_tmp = labels_convert_train(labels_tmp, w0, h0, w1, h1, ratio[0] * w1, ratio[1] * h1, pad[0], pad[1])

h2, w2 = img.shape[:2]
dict_json = {"filename": filename,
"lines": [[labels_tmp[0], labels_tmp[1], labels_tmp[2], labels_tmp[3]],
[labels_tmp[4], labels_tmp[5], labels_tmp[6], labels_tmp[7]]],
"height": h2, "width": w2}
cv2.imwrite(os.path.join(images_dir, filename), img)
return dict_json

if __name__ == '__main__':
random.seed(RANDOM_SEED)
parser = argparse.ArgumentParser(description="获取杭州人工认为有缺陷大图")
parser.add_argument('--data',
default=r"creepageDistance.yaml",
help="没有分的文件夹")
parser.add_argument('--datasets_path',
default=r"/home/zengxh/datasets/creepageDistance",
help="没有分的文件夹")
opt = parser.parse_args()

images_dir = os.path.join(opt.datasets_path,"images")
mkdir(images_dir)

with open(opt.data,encoding="utf-8") as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict

train_json = []
valid_json = []
for (key, data_dict_tmp) in data_dict["datasets"].items():
nameImgs = list(paths.list_images(os.path.join(data_dict_tmp["allDatas"])))
X_train, X_test_val, _, _ = train_test_split(nameImgs, nameImgs, test_size=0.2, random_state=RANDOM_SEED)

for imagePath in X_train:
dict_json = get_dict_json(imagePath)
train_json.append(dict_json)

for imagePath in X_test_val:
dict_json = get_dict_json(imagePath)
valid_json.append(dict_json)

save_json_file(os.path.join(opt.datasets_path,"train.json"),train_json)
save_json_file(os.path.join(opt.datasets_path,"valid.json"),valid_json)
12 changes: 7 additions & 5 deletions dataset/wireframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from docopt import docopt
from scipy.ndimage import zoom

from dataset.constants import NORMALIZATION_HEIGHT, NORMALIZATION_WIDTH

try:
sys.path.append(".")
sys.path.append("..")
Expand All @@ -44,13 +46,13 @@ def to_int(x):


def save_heatmap(prefix, image, lines):
im_rescale = (512, 512)
heatmap_scale = (128, 128)
im_rescale = (NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT)
heatmap_scale = (int(NORMALIZATION_WIDTH / 4), int(NORMALIZATION_HEIGHT / 4))

fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
jmap = np.zeros((1,) + heatmap_scale, dtype=np.float32)
joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32)
lmap = np.zeros(heatmap_scale, dtype=np.float32)
jmap = np.zeros((1,) + (heatmap_scale[1],heatmap_scale[0]), dtype=np.float32)
joff = np.zeros((1, 2) + (heatmap_scale[1],heatmap_scale[0]), dtype=np.float32)
lmap = np.zeros((heatmap_scale[1],heatmap_scale[0]), dtype=np.float32)

lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
Expand Down
12 changes: 7 additions & 5 deletions dataset/york.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from scipy.io import loadmat
from scipy.ndimage import zoom

from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT

try:
sys.path.append(".")
sys.path.append("..")
Expand All @@ -47,13 +49,13 @@ def to_int(x):


def save_heatmap(prefix, image, lines):
im_rescale = (512, 512)
heatmap_scale = (128, 128)
im_rescale = (NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT)
heatmap_scale = (int(NORMALIZATION_WIDTH / 4), int(NORMALIZATION_HEIGHT / 4))

fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
jmap = np.zeros((1,) + heatmap_scale, dtype=np.float32)
joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32)
lmap = np.zeros(heatmap_scale, dtype=np.float32)
jmap = np.zeros((1,) + (heatmap_scale[1],heatmap_scale[0]), dtype=np.float32)
joff = np.zeros((1, 2) + (heatmap_scale[1],heatmap_scale[0]), dtype=np.float32)
lmap = np.zeros((heatmap_scale[1],heatmap_scale[0]), dtype=np.float32)

lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
Expand Down
44 changes: 23 additions & 21 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from docopt import docopt

import lcnn
from dataset.constants import NORMALIZATION_WIDTH, NORMALIZATION_HEIGHT
from lcnn.config import C, M
from lcnn.models.line_vectorizer import LineVectorizer
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
Expand Down Expand Up @@ -89,7 +90,8 @@ def main():
if im.ndim == 2:
im = np.repeat(im[:, :, None], 3, 2)
im = im[:, :, :3]
im_resized = skimage.transform.resize(im, (512, 512)) * 255
im_resized = skimage.transform.resize(im, (NORMALIZATION_HEIGHT,NORMALIZATION_WIDTH )) * 255
# skimage.io.imsave('cat.jpg', im_resized)
image = (im_resized - M.image.mean) / M.image.stddev
image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
with torch.no_grad():
Expand All @@ -104,14 +106,14 @@ def main():
}
],
"target": {
"jmap": torch.zeros([1, 1, 128, 128]).to(device),
"joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
"jmap": torch.zeros([1, 1, int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)]).to(device),
"joff": torch.zeros([1, 1, 2, int(NORMALIZATION_HEIGHT / 4), int(NORMALIZATION_WIDTH / 4)]).to(device),
},
"mode": "testing",
}
H = model(input_dict)["preds"]

lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
lines = H["lines"][0].cpu().numpy() / (int(NORMALIZATION_HEIGHT / 4),int(NORMALIZATION_WIDTH / 4)) * im.shape[:2]
scores = H["score"][0].cpu().numpy()
for i in range(1, len(lines)):
if (lines[i] == lines[0]).all():
Expand All @@ -122,23 +124,23 @@ def main():
# postprocess lines to remove overlapped lines
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

for i, t in enumerate([0.94, 0.95, 0.96, 0.97, 0.98, 0.99]):
plt.gca().set_axis_off()
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
for (a, b), s in zip(nlines, nscores):
if s < t:
continue
plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
plt.scatter(a[1], a[0], **PLTOPTS)
plt.scatter(b[1], b[0], **PLTOPTS)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(im)
plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight")
plt.show()
plt.close()
print(nlines)
# for i, t in enumerate([0.01, 0.95, 0.96, 0.97, 0.98, 0.99]):
plt.gca().set_axis_off()
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
for (a, b), s in zip(nlines, nscores):
# if s < t:
# continue
plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
plt.scatter(a[1], a[0], **PLTOPTS)
plt.scatter(b[1], b[0], **PLTOPTS)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(im)
plt.savefig(imname.replace(".png", f"-{0.1:.02f}.svg"), bbox_inches="tight")
plt.show()
plt.close()


if __name__ == "__main__":
Expand Down
Loading