Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hua committed Nov 25, 2018
1 parent 0af01ee commit 18659b7
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 245 deletions.
64 changes: 36 additions & 28 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified lib/__pycache__/common.cpython-36.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion lib/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def print_start():
"""
打印 开始 logog
打印 开始 logo
:return:
"""

Expand Down
97 changes: 30 additions & 67 deletions lib/gen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@
from lib.gen.static import *

# 路径先采用绝对路径
test_gen_path = '/Volumes/doc/projects/ml/vin_recognize/static/test/'
images_path = "/Volumes/doc/projects/ml/vin_recognize/static/images/"
font_path = '/Volumes/doc/projects/ml/vin_recognize/static/font/'
bad_path = '/Volumes/doc/projects/ml/vin_recognize/static/bad/'
predict_path = '/Volumes/doc/projects/ml/vin_recognize/static/predict/'
font = 'OCR-B.ttf'


class GenVin:

def __init__(self, height, width,predict_path=predict_path,fontSize=30,
NoPlates=bad_path,font_path=font_path,
images_path=images_path,font=font):
def __init__(self, height, width, predict_path=predict_path, fontSize=30,
NoPlates=bad_path, font_path=font_path,
images_path=images_path, font=font):

self.fontSize = fontSize
self.font = ImageFont.truetype(font_path + font, self.fontSize, 0)
Expand Down Expand Up @@ -56,11 +58,12 @@ def draw(self, val):
:param val:
:return:
"""
height,width = self.bg.shape
height, width = self.bg.shape
offset = 0
base = int((width - self.max_size*self.fontSize)/2)
base = int((width - self.max_size * self.fontSize) / 2)
for i, item in enumerate(val):
self.bg[int((height-self.fontSize)/2):int((height+self.fontSize)/2), base: base + self.fontSize] = self.draw_item(item, 0, 0)
self.bg[int((height - self.fontSize) / 2):int((height + self.fontSize) / 2),
base: base + self.fontSize] = self.draw_item(item, 0, 0)
base += self.fontSize + offset
return self.bg

Expand All @@ -86,13 +89,13 @@ def GenCh(self, val, img):
"""
img = Image.new("RGB", (self.fontSize, self.bg.shape[1]), (0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((0, int(self.bg.shape[1]-self.fontSize)/2), val, (255, 255), font=self.font)
draw.text((0, int(self.bg.shape[1] - self.fontSize) / 2), val, (255, 255), font=self.font)
A = np.array(img)
return A

def random_start(self,seed):
def random_start(self, seed):
ran = r(seed)
return ran%3 == 0
return ran % 3 == 0

def generate(self, text):
fg = self.draw(text)
Expand Down Expand Up @@ -125,7 +128,7 @@ def gen_sample(self):
return label, img[:, :, 2] # 返回的label为标签,img为深度为3的图像像素

def gen_img(self):
text, label = self.get_text_v2()
text, label = self.get_text()
img = self.generate(text)
img = cv2.resize(img, (self.width, self.height))
img = np.multiply(img, 1 / 255.0) # [height,width,channel]
Expand All @@ -135,13 +138,13 @@ def gen_img(self):
def get_next_batch(self, batch_size=128):
batch_x = np.zeros([batch_size, self.width * self.height])
batch_y = np.zeros([batch_size, self.len * self.max_size])
labels = []
labels = []
for i in range(batch_size):
image, text, vec = self.gen_img()
labels.append(text)
batch_x[i, :] = image.reshape((self.width * self.height))
batch_y[i, :] = vec
return batch_x, batch_y,labels
return batch_x, batch_y, labels

def get_predict_source(self):
items = os.listdir(self.predict_path)
Expand All @@ -150,69 +153,44 @@ def get_predict_source(self):
textes = []
for i, item in enumerate(items):
textes.append(item.split("/")[-1])
img = cv2.imread(os.path.join(path,item), cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img,(self.width,self.height))
img = cv2.imread(os.path.join(self.predict_path, item), cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (self.width, self.height))
batch_x[i, :] = img.reshape((self.width * self.height))

return batch_x,textes
return batch_x, textes


def text2vec(self,text):
def text2vec(self, text):
text_len = len(text)
if text_len > self.max_size:
raise ValueError('字符最长 ',self.max_size,'个字符')
raise ValueError('字符最长 ', self.max_size, '个字符')
vector = np.zeros(self.max_size * self.len)
for i, c in enumerate(text):
idx = i * self.len + chars_indexes[c]
vector[idx] = 1
return vector

# 向量转回文本
def vec2text(self,vec):
def vec2text(self, vec):
char_pos = vec.nonzero()[0]
text = []
for i, c in enumerate(char_pos):
char_idx = c % self.len
text.append(chars[char_idx])
return "".join(text)

def genBatch(self, batchSize, outputPath):
if (not os.path.exists(outputPath)):
os.mkdir(outputPath)
for i in range(batchSize):
# plateStr = self.genPlateString(-1, -1)
plateStr, label = self.get_text()
img = self.generate(plateStr)
img = cv2.resize(img, (self.width, self.height))
filename = os.path.join(outputPath, str(i).zfill(4) + '.' + "".join(plateStr) + ".jpg")
cv2.imwrite(filename, img)

def genBatch_v2(self, batchSize, outputPath):
def gen_Batch(self, batchSize, outputPath):
if (not os.path.exists(outputPath)):
os.mkdir(outputPath)

for i in range(batchSize):
# plateStr = self.genPlateString(-1, -1)
plateStr, label = self.get_text_v2()
plateStr, label = self.get_text()
img = self.generate(plateStr)
img = cv2.resize(img, (self.width, self.height))
filename = os.path.join(outputPath, str(i).zfill(4) + '.' + "".join(plateStr) + ".png")
cv2.imwrite(filename, img)

def get_text(self):
text = []
label = []
text += (self.get_random_char(STATE_CODE, 1)) # 生成车牌
text += (self.get_random_char(chars, 2))
text += (self.get_random_char(FOUR_CODE, 1)) # 4
text += (self.get_random_char(chars, 4)) # 8
text += (self.get_random_char(NINE_CODE, 1)) # 9
text += (self.get_random_char(chars, 8))
for item in text:
label.append(chars_indexes[item])
return ''.join(text), np.array(label)

def get_text_v2(self):
"""
# 随机生成字串,长度固定
# 返回text,及对应的向量
Expand All @@ -226,11 +204,6 @@ def get_text_v2(self):
text += (self.get_random_char(chars, 4)) # 8
text += (self.get_random_char(NINE_CODE, 1)) # 9
text += (self.get_random_char(chars, 8))
# size = random.randint(1, self.max_size)
# size = self.max_size
# for i, c in enumerate(text):
# vec = self.text2vec(c)
# vecs[i * self.len:(i + 1) * self.len] = np.copy(vec)
return text, self.text2vec(text)

def get_random_char(self, chars, size):
Expand All @@ -240,21 +213,11 @@ def get_random_char(self, chars, size):
for i in ints: result.append(chars[i])
return result

# generator = GenVin("font.ttf", 49, 258, fontSize=30,NoPlates=bad_path)
# if __name__ == '__main__':
# generator = GenVin("./font/DejaVuSansMono.ttf",80,440)
# print("".join(generator.get_text()))
# generator = GenVin("./font/Bitter-Regular.ttf")
# generator = GenVin("./font/SimHei.ttf")
# label,img = generator.gen_sample()
# print(label)d
# cv2.imshow("test",img)
# cv2.waitKeyEx(0)
# generator.genBatch_v2(3, "../../data/train/energy")
# path = "/Volumes/doc/projects/ml/ocr_tensorflow_cnn/data/train/energy/"
# for item in os.listdir(path):
# img = cv2.imread(os.path.join(path,item), cv2.IMREAD_GRAYSCALE)
# print(img.shape)
# cv2.imshow("123", img)
# cv2.waitKey(0)

generator = GenVin(49, 258)
if __name__ == '__main__':
# generator = GenVin("./font/DejaVuSansMono.ttf",80,440)
# print("".join(generator.get_text()))
# generator = GenVin("./font/Bitter-Regular.ttf")
# generator = GenVin("./font/SimHei.ttf")
generator.gen_Batch(3, test_gen_path)
2 changes: 2 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# vin recognize
> 车辆 vin 码识别
## 关键在于样本的生成

2 changes: 0 additions & 2 deletions src/net_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import tensorflow as tf
from lib.gen.generator import GenVin
import numpy as np
from lib.logger import logger


class NetWork():

Expand Down
Loading

0 comments on commit 18659b7

Please sign in to comment.