Skip to content

Commit

Permalink
🐛 2021年7月4日16:08:12 修复预训练模型会下载失败的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
ExcaliburEX committed Jul 4, 2021
1 parent d414dcd commit 66d9c9f
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 12 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
---

最新release:
<h3 align="center"><a href= 'https://github.com/ExcaliburEX/FingerveinRecognition/releases/download/V1.0/FingerveinRecogntion.exe'><img alt="GitHub release (latest by date)" src="https://img.shields.io/github/downloads/ExcaliburEX/FingerveinRecogntion/V1.0/total?color=success&flat-square&logo=Cachet"></a></h3>
<h3 align="center"><a href= 'https://github.com/ExcaliburEX/FingerveinRecognition/releases/download/V1.1/FingerveinRecogntion.exe'><img alt="GitHub release (latest by date)" src="https://img.shields.io/github/downloads/ExcaliburEX/FingerveinRecognition/V1.1/total?color=success&flat-square&logo=Cachet"></a></h3>

# 更新日志
- 2021-07-04:因为VGG19下载的预训练模型在Github导致有可能下载失败,因此将模型文件存储到自己的腾讯云OSS桶,通过程序自动下载,手动设置了VGG19的预训练模型的路径,规避了因为网络问题导致的下载失败。

# 使用说明

Expand Down
116 changes: 105 additions & 11 deletions finger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from tensorflow.compat.v1 import logging as log
log.set_verbosity(log.ERROR)
import time
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.applications.vgg19 import VGG19, preprocess_input
from keras.preprocessing import image
from PIL import Image
import numpy as np
import random
import ssl
# import ssl


def get_feature(path,model):
Expand All @@ -38,6 +38,7 @@ def cos_sim(a, b):
import PySimpleGUI as sg
from io import BytesIO
import base64
import requests

def gaborfilter(I, S, F, W, P):
size = int(1.5 / S)
Expand Down Expand Up @@ -133,7 +134,7 @@ def Train(file,flag,window,model):
f.close()
return feature,1
else:
print('\n',time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "图片 %s 已识别过..."%(name))
print(time.strftime("\n[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "图片 %s 已识别过..."%(name))
return None,2

def Test(trainPic, readPic, finger_num,num,window,model):
Expand Down Expand Up @@ -220,18 +221,22 @@ def GUI():
layout = [
[
sg.Text('单人手指图片个数:', font=("KaiTi", 12),justification='left',relief=sg.RELIEF_RIDGE),
# sg.InputText('7', font=("KaiTi", 12), size=(17, 2), key='-RANGE1-'),
sg.InputText('', font=("KaiTi", 12), size=(17, 2), key='-RANGE1-'),
sg.Text('随机选取的图片测试个数:', font=("KaiTi", 12),justification='left',relief=sg.RELIEF_RIDGE),
# sg.InputText('5', font=("KaiTi", 12), size=(17, 2), key='-RANGE2-')
sg.InputText('', font=("KaiTi", 12), size=(17, 2), key='-RANGE2-')
],
[
sg.Text('训练集图片文件夹:', font=("KaiTi", 12),justification='left',relief=sg.RELIEF_RIDGE),
sg.In(size=(60, 1), enable_events=True, key="-FOLDER1-"),
# sg.In('C:/Users/Excalibur/Desktop/Test/V1.0/HighGuardFinger',size=(60, 1), enable_events=True, key="-FOLDER1-"),
sg.In('',size=(60, 1), enable_events=True, key="-FOLDER1-"),
sg.FolderBrowse('浏览', button_color=('Lavender', 'BlueViolet'), font=("KaiTi", 10),size=(8, 1))
],
[
sg.Text('待识别图片文件夹:', font=("KaiTi", 12),justification='left',relief=sg.RELIEF_RIDGE),
sg.In(size=(60, 1), enable_events=True, key="-FOLDER2-"),
# sg.In('C:/Users/Excalibur/Desktop/Test/V1.0/HighGuardTest',size=(60, 1), enable_events=True, key="-FOLDER2-"),
sg.In('',size=(60, 1), enable_events=True, key="-FOLDER2-"),
sg.FolderBrowse('浏览', button_color=('Lavender', 'BlueViolet'), font=("KaiTi", 10),size=(8, 1))
],
[sg.Column(col, background_color='papayawhip', size=(
Expand All @@ -244,7 +249,7 @@ def GUI():
sg.Exit('退出', button_color=('white', 'firebrick4'),key='Exit', size=(37, 1), font=("KaiTi", 12))
]
]
window = sg.Window('HighGuardFingerRecV1.0', layout,
window = sg.Window('HighGuardFingerRecV1.1', layout,
default_element_size=(80, 1), resizable=True, element_justification='center', text_justification='center',finalize=True)


Expand All @@ -269,30 +274,119 @@ def GUI():



# 判斷模型文件是否存在
def FileExist():
filename = 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'
# username = getpass.getuser()
# fileloc = r'C:\Users\\' + username + '\.keras\models\\'
fileloc = os.getcwd() + '\\'
fullname = fileloc + filename
if os.path.exists(fullname):
return True,fileloc
else:
return False,fileloc

# 手动下载模型文件
def progressbar(url,path,file_name,window):
if not os.path.exists(path): # 看是否有该文件夹,没有则创建文件夹
os.mkdir(path)
file_path = os.path.join(path, file_name)
start = time.time() #下载开始时间
response = requests.get(url, stream=True)
size = 0 #初始化已下载大小
chunk_size = 1024 * 1024 * 3 # 每次下载的数据大小
content_size = int(response.headers['content-length']) # 下载文件总大小
count_tmp = 0
count = 0
try:
if response.status_code == 200: #判断是否响应成功
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()),'开始下载,[模型-大小] %s : %.2f MB'%(file_name,content_size / chunk_size * 3)) #开始下载,显示下载文件大小
with open(file_name,'wb') as file: #显示进度条
for data in response.iter_content(chunk_size = chunk_size):
file.write(data)
count += len(data)
size += len(data)
speed = (count - count_tmp) / 1024 / 1024
count_tmp = count
end = time.time() #下载结束时间
rate = int((size / content_size) * 20)
len1 = '>' * rate
len2 = '_' * (20-rate)
window['OUTPUT'].update(value=time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()) + '模型总大小:%.2f MB,下载速度:%.2f M/S \n'%(content_size / chunk_size * 3,speed) +
time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()) + '下载进度: [%s%s] |%d / %d| %.2f%%' % (len1,len2,size,content_size, float(size / content_size * 100)) +
time.strftime("\n[%Y-%m-%d %H:%M:%S]: ",time.localtime()) + '已经耗时%.2f秒'%(end - start))
time.sleep(0.1)
end = time.time() #下载结束时间
print(time.strftime("\n[%Y-%m-%d %H:%M:%S]: ",time.localtime()),'模型下载成功!,耗时: %.2f秒' % (end - start)) #输出下载用时时间
return True
except:
end = time.time() #下载结束时间
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()),'模型下载失败!,耗时: %.2f秒,请检查网络及相关设置再继续训练' % (end - start)) #输出下载用时时间
return False


# 暂时放弃的第一版下载方案
# def DownloadFile2(model_url, save_url,file_name):
# try:
# if model_url is None or save_url is None or file_name is None:
# print('参数错误')
# return None
# folder = os.path.exists(save_url)
# if not folder:
# os.makedirs(save_url)
# res = requests.get(model_url,stream=True)
# total_size = int(int(res.headers["Content-Length"])/1024+0.5)
# file_path = os.path.join(save_url, file_name)
# from tqdm import tqdm
# with open(file_path, 'wb') as fd:
# print('开始下载文件:{},当前文件大小:{}KB'.format(file_name,total_size))
# for chunk in tqdm(iterable=res.iter_content(1024),total=total_size,unit='k',desc=None):
# fd.write(chunk)
# print(file_name+' 下载完成!')
# except:
# print("程序错误")






# 批量训练与识别
def BatchTrain(directory,window):
model_name = 'VGG19'
model_full_name = 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "正在加载识别模型..." )
ssl._create_default_https_context = ssl._create_unverified_context # 下载模型的时候不想进行ssl证书校验
# ssl._create_default_https_context = ssl._create_unverified_context # 下载模型的时候不想进行ssl证书校验
starttime = time.time()
model = VGG16(weights='imagenet', include_top=False)
file_name = 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'
model_url = 'https://blog-1259799643.cos.ap-shanghai.myqcloud.com/' + file_name
flag,loc = FileExist()
if flag == False:
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "初次使用需要下载 %s 预训练模型!"%(model_name))
downloadflag = progressbar(model_url, loc, file_name,window)
if downloadflag == False:
return
model = VGG19(weights=os.getcwd() + '\\' + model_full_name, include_top=False)
elapse = time.time() - starttime
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "加载完毕,耗时:%.3fs" % (elapse))
filelist = os.listdir(directory)
cnt = 1
starttime2 = time.time()
for f in filelist:
window['OUTPUT'].update(value=time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()) + "正在训练第 %d / %d 个图片..."%(cnt,len(filelist)))
elapse2 = time.time() - starttime2
window['OUTPUT'].update(value=time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()) + "正在训练第 %d / %d 个图片: %s,已耗时:%.2fs,单张平均耗时:%.2fs ..."%(cnt,len(filelist),f,elapse2,elapse2/cnt))
Train(directory + '/' + f,1,window,model)
cnt += 1
elapse2 = time.time() - starttime2
print('\n',time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "训练完毕,耗时:%.3fs" % (elapse2))
print(time.strftime("\n[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "训练完毕,耗时:%.3fs" % (elapse2))


def BatchTest(trainPic, readPicDir, finger_num,num,window):
model_full_name = 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "单人手指数:%d张,随机抽取其中 %d 张进行比对!"%(int(finger_num),int(num)) )
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "正在加载识别模型..." )
starttime = time.time()
model = VGG16(weights='imagenet', include_top=False)
model = VGG19(weights=os.getcwd() + '\\' + model_full_name, include_top=False)
elapse = time.time() - starttime
print(time.strftime("[%Y-%m-%d %H:%M:%S]: ",time.localtime()), "加载完毕,耗时:%.3fs" % (elapse))
filelist = os.listdir(readPicDir)
Expand Down

0 comments on commit 66d9c9f

Please sign in to comment.