Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lithium0003 committed Aug 25, 2023
1 parent 26073b3 commit 5837286
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions convert1_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import onnx
import onnxruntime
import numpy as np
from numpy.lib.stride_tricks import as_strided
from PIL import Image
import matplotlib
matplotlib.use('Agg')
Expand All @@ -12,6 +13,21 @@

import net

def maxpool2d(input_matrix, kernel_size):
    """Stride-1 max pooling over a 2-D array with 'SAME'-style padding.

    Every output element is the maximum of the ``kernel_size`` x
    ``kernel_size`` window anchored on the matching input element, so the
    result always has the same shape as the input.  Cells outside the input
    are filled with the smallest value of the dtype and therefore never win
    the maximum.

    Args:
        input_matrix: 2-D array-like of values to pool.
        kernel_size: Side length of the square pooling window (>= 1).

    Returns:
        ``numpy.ndarray`` with the same shape and dtype as ``input_matrix``.

    Raises:
        ValueError: If ``input_matrix`` is not 2-D or ``kernel_size`` < 1.
    """
    input_matrix = np.asarray(input_matrix)
    if input_matrix.ndim != 2:
        raise ValueError(
            'maxpool2d expects a 2-D array, got %d dimension(s)' % input_matrix.ndim)
    if kernel_size < 1:
        raise ValueError('kernel_size must be >= 1, got %r' % (kernel_size,))

    # Pick a padding value that can never be the window maximum.  -inf only
    # exists for floating dtypes; integer/bool inputs need their own minimum.
    dtype = input_matrix.dtype
    if np.issubdtype(dtype, np.integer):
        fill_value = np.iinfo(dtype).min
    elif np.issubdtype(dtype, np.bool_):
        fill_value = False
    else:
        fill_value = -np.inf

    # Total padding per axis is kernel_size - 1 so that the output keeps the
    # input shape.  For even kernels the extra cell goes after the data;
    # for odd kernels this is the symmetric kernel_size // 2 padding.
    pad_before = (kernel_size - 1) // 2
    pad_after = kernel_size // 2
    padded = np.pad(input_matrix,
                    ((pad_before, pad_after), (pad_before, pad_after)),
                    constant_values=fill_value)

    # Zero-copy 4-D view: windows[i, j] is the kernel window anchored at
    # (i, j).  Reusing the padded strides for both the position axes and the
    # in-window axes makes neighbouring windows overlap in memory, hence the
    # read-only flag.
    rows, cols = input_matrix.shape
    windows = as_strided(padded,
                         shape=(rows, cols, kernel_size, kernel_size),
                         strides=padded.strides * 2,
                         writeable=False)

    # Reduce over the two in-window axes directly; no intermediate reshape
    # (which would materialise every window) is needed.
    return windows.max(axis=(2, 3))

class TextDetectorModel(tf.keras.models.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -37,27 +53,14 @@ def convert1(ckpt_dir='ckpt1'):
inputs = tf.keras.Input(shape=(net.height,net.width,3), name='image_input')
heatmap, feature = model.detector(inputs)

keymap = heatmap[...,0]
keymape = tf.expand_dims(keymap, axis=-1)
local_peak = tf.nn.max_pool2d(keymape,5,1,'SAME')
local_peak = local_peak[...,0]

textlines = heatmap[...,5]
separator = heatmap[...,6]
xsize = heatmap[...,1]
ysize = heatmap[...,2]
xoffset = heatmap[...,3]
yoffset = heatmap[...,4]
code_map = []
for k in range(4):
code_map.append(heatmap[...,7+k])

outputs = [
tf.keras.layers.Lambda(lambda x: x, name='maps', dtype='float32')(tf.stack([keymap, local_peak, xsize, ysize, xoffset, yoffset, textlines, separator, *code_map], axis=-1)),
tf.keras.layers.Lambda(lambda x: x, name='maps', dtype='float32')(heatmap),
tf.keras.layers.Lambda(lambda x: x, name='feature', dtype='float32')(feature),
]
detector = tf.keras.Model(inputs, outputs, name='TextDetector')

# input_signature = [tf.TensorSpec([1] + list(inputs.shape[1:]), tf.float32, name='image_input')]
# tf2onnx.convert.from_keras(detector, input_signature=input_signature, output_path='TextDetector.onnx')
tf2onnx.convert.from_keras(detector, output_path='TextDetector.onnx')

onnx.checker.check_model('TextDetector.onnx')
Expand All @@ -71,6 +74,8 @@ def convert1(ckpt_dir='ckpt1'):
outputs.append(tf.keras.layers.Lambda(lambda x: x, name='mod_%d'%mod_id, dtype='float32')(decoder_id))
decoder = tf.keras.Model(embedded, outputs, name='CodeDecoder')

# input_signature = [tf.TensorSpec([1] + list(embedded.shape[1:]), tf.float32, name='feature_input')]
# tf2onnx.convert.from_keras(decoder, input_signature=input_signature, output_path='CodeDecoder.onnx')
tf2onnx.convert.from_keras(decoder, output_path='CodeDecoder.onnx')

onnx.checker.check_model('CodeDecoder.onnx')
Expand Down Expand Up @@ -155,8 +160,9 @@ def test_model():
# print(session_output.name, session_output.shape)

maps, feature = onnx_detector.run(['maps','feature'], {'image_input': image_input})
peakmap = 1/(1 + np.exp(-maps[0,:,:,0]))
peakmap = np.where(maps[0,:,:,0] == maps[0,:,:,1], peakmap, 0.)
peaks = maps[0,:,:,0] + np.random.normal(0,1e-3,size=(net.height//net.scale,net.width//net.scale))
maxvalue = maxpool2d(peaks, 5)
peakmap = np.where(peaks == maxvalue, 1/(1 + np.exp(-maps[0,:,:,0])), 0.)
idxy, idxx = np.unravel_index(np.argsort(-peakmap.ravel()), peakmap.shape)
results_dict = []
for y, x in zip(idxy, idxx):
Expand Down

0 comments on commit 5837286

Please sign in to comment.