-
Notifications
You must be signed in to change notification settings - Fork 18.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added prototype for ctc decoder and implemented ctc_loss_layer with t…
…ests. Added reverse layer (useful for bidirectional recurrent layers, e.g. BLSTM), finished working on CTC-Loss-Layer, more tests. Separated forward and backward pass by introducing new intermediate variables (e.g. alpha and beta). CTCDecoderLayer: added scores and optional accuracy as top blobs. Implemented CTCDecoderLayerTest for GreedyDecoder. Added parameters to ctc decoder layer into proto. Added dummy example to ctc examples. Added an example to show the progress of learning. Fixed lint errors, made layout changes
- Loading branch information
Showing
20 changed files
with
3,004 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# ignore generated data | ||
*.h5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# CTC Examples | ||
|
||
## Dummy Example | ||
|
||
The 'dummy' example shows the basic usage of the CTC layer combined with LSTM networks. Use the `make_dummy_data.py` script to create a single batch of data, which is stored in an HDF5 file. Then pass one of the solver prototxts to Caffe to run the examples. | ||
|
||
The solver will overfit the network on this single batch of data, so the loss should decrease steadily (more or less quickly). | ||
|
||
### LSTM | ||
|
||
The `dummy_lstm_solver.prototxt` implements a standard unidirectional LSTM model. | ||
|
||
### BLSTM | ||
|
||
The `dummy_blstm_solver.prototxt` implements a bidirectional LSTM model using the ReverseLayer. | ||
|
||
### Plot progress of learning | ||
|
||
The `plot_dummy_progress.py` script will show the learning progress, i.e. the input probabilities and their diffs for each label over time. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
|
||
layer { | ||
name: "data" | ||
type: "HDF5Data" | ||
top: "data" | ||
top: "seq_ind" | ||
top: "labels" | ||
hdf5_data_param { | ||
source: "dummy_data.txt" | ||
batch_size: 40 | ||
} | ||
} | ||
|
||
layer { | ||
name: "lstm1" | ||
type: "LSTM" | ||
bottom: "data" | ||
bottom: "seq_ind" | ||
top: "lstm1" | ||
recurrent_param { | ||
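# number of LSTM hidden units; the label count (num labels + 1 for the blank label, which is the last one) is set by the InnerProduct layer below | ||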
num_output: 40 | ||
weight_filler { | ||
type: "gaussian" | ||
std: 0.1 | ||
} | ||
bias_filler { | ||
type: "constant" | ||
} | ||
} | ||
} | ||
|
||
layer { | ||
name: "ip1" | ||
type: "InnerProduct" | ||
bottom: "lstm1" | ||
top: "ip1" | ||
inner_product_param { | ||
num_output: 6 | ||
weight_filler { | ||
type: "gaussian" | ||
std: 0.1 | ||
} | ||
axis: 2 | ||
} | ||
} | ||
|
||
layer { | ||
name: "rev1" | ||
type: "Reverse" | ||
bottom: "data" | ||
top: "rev_data" | ||
} | ||
|
||
layer { | ||
name: "lstm2" | ||
type: "LSTM" | ||
bottom: "rev_data" | ||
bottom: "seq_ind" | ||
top: "lstm2" | ||
recurrent_param { | ||
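# number of LSTM hidden units; the label count (num labels + 1 for the blank label, which is the last one) is set by the InnerProduct layer below | ||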
num_output: 40 | ||
weight_filler { | ||
type: "gaussian" | ||
std: 0.1 | ||
} | ||
bias_filler { | ||
type: "constant" | ||
} | ||
} | ||
} | ||
|
||
|
||
layer { | ||
name: "ip2" | ||
type: "InnerProduct" | ||
bottom: "lstm2" | ||
top: "ip2" | ||
inner_product_param { | ||
num_output: 6 | ||
weight_filler { | ||
type: "gaussian" | ||
std: 0.1 | ||
} | ||
axis: 2 | ||
} | ||
} | ||
|
||
layer { | ||
name: "rev2" | ||
type: "Reverse" | ||
bottom: "ip2" | ||
top: "rev2" | ||
} | ||
|
||
layer { | ||
name: "eltwise-sum" | ||
type: "Eltwise" | ||
bottom: "ip1" | ||
bottom: "rev2" | ||
eltwise_param { operation: SUM } | ||
top: "sum" | ||
} | ||
|
||
layer { | ||
name: "loss" | ||
type: "CTCLoss" | ||
bottom: "sum" | ||
bottom: "seq_ind" | ||
bottom: "labels" | ||
top: "ctc_loss" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
net: "dummy_blstm_net.prototxt" | ||
base_lr: 0.01 | ||
lr_policy: "fixed" | ||
display: 100 | ||
max_iter: 100000 | ||
solver_mode: CPU | ||
average_loss: 1 | ||
solver_type: SGD | ||
random_seed: 9602 | ||
clip_gradients: 10 | ||
debug_info: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
dummy_data.h5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
|
||
layer { | ||
name: "data" | ||
type: "HDF5Data" | ||
top: "data" | ||
top: "seq_ind" | ||
top: "labels" | ||
hdf5_data_param { | ||
source: "dummy_data.txt" | ||
batch_size: 40 | ||
} | ||
} | ||
|
||
layer { | ||
name: "lstm1" | ||
type: "LSTM" | ||
bottom: "data" | ||
bottom: "seq_ind" | ||
top: "lstm1" | ||
recurrent_param { | ||
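# number of LSTM hidden units; the label count (num labels + 1 for the blank label, which is the last one) is set by the InnerProduct layer below | ||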
num_output: 40 | ||
weight_filler { | ||
type: "gaussian" | ||
std: 0.1 | ||
} | ||
bias_filler { | ||
type: "constant" | ||
} | ||
} | ||
} | ||
|
||
layer { | ||
name: "ip1" | ||
type: "InnerProduct" | ||
bottom: "lstm1" | ||
top: "ip1" | ||
inner_product_param { | ||
num_output: 6 | ||
weight_filler { | ||
type: "gaussian" | ||
std: 0.1 | ||
} | ||
axis: 2 | ||
} | ||
} | ||
|
||
layer { | ||
name: "loss" | ||
type: "CTCLoss" | ||
bottom: "ip1" | ||
bottom: "seq_ind" | ||
bottom: "labels" | ||
top: "ctc_loss" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
net: "dummy_lstm_net.prototxt" | ||
base_lr: 0.01 | ||
lr_policy: "fixed" | ||
display: 100 | ||
max_iter: 100000 | ||
solver_mode: CPU | ||
average_loss: 1 | ||
solver_type: SGD | ||
random_seed: 9602 | ||
clip_gradients: 10 | ||
debug_info: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numpy as np | ||
import h5py | ||
|
||
def store_hdf5(filename, mapping):
    """Write every array in ``mapping`` to an HDF5 file as its own dataset.

    Args:
        filename (str): Path of the output file. It is opened with mode
            ``'w'``, so it is created (or replaced if it already exists).
        mapping (dict): Maps a dataset name to the numpy data stored under
            that name. Each value must expose a ``shape`` attribute, which
            is printed for progress reporting.
    """
    print("Storing hdf5 file %s" % filename)
    # The context manager closes the file even if a dataset fails to write.
    with h5py.File(filename, 'w') as hf:
        for label, data in mapping.items():
            print(" adding dataset %s with shape %s" % (label, data.shape))
            hf.create_dataset(label, data=data)

    print(" finished")
|
||
def generate_data(T_, C_, lab_len_):
    """Generate one deterministic dummy sequence for the CTC examples.

    The data is not random; it is produced by a fixed closed-form function
    of the time step and the channel index. The sequence length is exactly
    ``T_`` and the target label sequence is ``[0, 1, ..., lab_len_ - 1]``.

    Args:
        T_ (int): The number of time steps. This value must match the
            ``batch_size`` of the Caffe net.
        C_ (int): The number of input channels per time step.
        lab_len_ (int): The length of the target label sequence; must be
            smaller than or equal to ``T_``. Because the targets are
            ``0 .. lab_len_ - 1``, the network needs ``lab_len_ + 1``
            outputs (one extra for the blank label), e.g. 6 = 5 + 1.

    Returns:
        tuple: ``(data, sequence_indicators, labels)`` where

        * ``data`` is a float32 array of shape ``(T_, 1, C_)`` holding the
          dummy input values,
        * ``sequence_indicators`` is a float32 array of shape ``(T_, 1)``
          that is 0 at the first time step and 1 everywhere else,
        * ``labels`` is a float32 array of shape ``(T_, 1)`` whose first
          ``lab_len_`` entries are ``0 .. lab_len_ - 1`` and whose
          remaining entries are -1 (indicating the end of the sequence).

    Raises:
        ValueError: If ``lab_len_`` is negative or greater than ``T_``.
    """
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if not 0 <= lab_len_ <= T_:
        raise ValueError(
            "lab_len_ (%d) must be in the range [0, T_] with T_ = %d"
            % (lab_len_, T_))

    # Arbitrary function used to generate non-random data. It is evaluated
    # for all (t, c) pairs at once by broadcasting a (T_, 1, 1) time axis
    # against a (1, 1, C_) channel axis.
    t = np.arange(T_, dtype=np.float64).reshape(T_, 1, 1)
    c = np.arange(C_, dtype=np.float64).reshape(1, 1, C_)
    data = (((c * 0.1 / C_) ** 2 - 0.25 + t * 0.2 / T_)
            * (T_ // 5) / T_).astype(np.float32)

    # The sequence length is exactly T_: only the first step is marked 0.
    sequence_indicators = np.ones((T_, 1), dtype=np.float32)
    # Slice assignment (rather than index 0) also tolerates T_ == 0.
    sequence_indicators[:1] = 0

    # The label length is lab_len_; the output sequence is
    # [0 1 2 ... lab_len_-1], padded with -1 up to T_ entries.
    labels = np.full((T_, 1), -1, dtype=np.float32)
    labels[:lab_len_, 0] = np.arange(lab_len_)

    return data, sequence_indicators, labels
|
||
|
||
if __name__ == "__main__":
    # Generate the dummy data.
    # Note that T_ = 40 must match the batch_size of 40 in the network setup,
    # as required by the CTC algorithm to see the full sequence.
    # The label length and max label are set to 5. Use 6 = 5 + 1 for the
    # label size in the network to add the blank label.
    data, sequence_indicators, labels = generate_data(40, 20, 5)

    # Write it to the h5 file; one dataset is stored per dictionary key.
    store_hdf5(
        "dummy_data.h5",
        {"data": data, "seq_ind": sequence_indicators, "labels": labels},
    )
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import caffe | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import matplotlib.animation as animation | ||
|
||
# Build the solver for the bidirectional dummy network.
solver = caffe.SGDSolver("dummy_blstm_solver.prototxt")

# The 'sum' blob is what gets plotted; its shape sizes the figure below.
# NOTE(review): presumably laid out as (time, batch, label) - confirm.
data = solver.net.blobs['sum'].data
shape = data.shape

# x axis: one tick per entry along the first blob axis.
t = range(shape[0])

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

ax1.set_ylabel('label prediction probabilities')
ax2.set_ylabel('gradients')
ax1.set_xlabel('Time')
ax2.set_xlabel('Time')
# One line artist per entry along the last blob axis. The sin/cos curves are
# only placeholders used to create the artists; init() masks them and
# update_plot() overwrites them with real values.
data_p = [ax1.plot(t, np.sin(t))[0] for _ in range(shape[2])]
diff_p = [ax2.plot(t, np.cos(t))[0] for _ in range(shape[2])]
|
||
def init():
    """Blank all line artists before the animation starts.

    Every line in ``data_p`` and ``diff_p`` gets a fully masked y-array of
    length ``shape[0]``, so nothing is drawn until real values are set.

    Returns:
        list: The data line artists followed by the diff line artists.
    """
    artists = data_p + diff_p
    for line in artists:
        # A fresh masked array per line; mask=True hides every point.
        line.set_ydata(np.ma.array(np.zeros(shape[0]), mask=True))

    return artists
|
||
def update_plot(data, diff):
    """Copy the current blob values into the line artists.

    Args:
        data: Array indexed as ``data[:, 0, i]``; slice ``i`` becomes the
            y-data of the i-th line in ``data_p``.
        diff: Array indexed as ``diff[:, 0, i]``; slice ``i`` becomes the
            y-data of the i-th line in ``diff_p``.

    Returns:
        list: The data line artists followed by the diff line artists.
    """
    for i, (data_line, diff_line) in enumerate(zip(data_p, diff_p)):
        # Only the first entry along axis 1 is plotted.
        data_line.set_ydata(data[:, 0, i])
        diff_line.set_ydata(diff[:, 0, i])

    return data_p + diff_p
|
||
|
||
def animate(i):
    """Advance training by 100 solver steps and refresh both plots.

    Args:
        i: Frame value supplied by ``FuncAnimation``; not used.

    Returns:
        list: The line artists returned by ``update_plot``.
    """
    solver.step(100)
    data = solver.net.blobs['sum'].data
    diff = solver.net.blobs['sum'].diff

    # Set the new y-data first, then recompute the limits. Rescaling before
    # the update would fit the axes to the previous frame's values.
    artists = update_plot(data, diff)
    for ax in (ax1, ax2):
        ax.relim()
        ax.autoscale_view(True, True, True)

    return artists
|
||
|
||
# Keep a reference to the animation object for as long as the window is
# open. Frames 1..99 are requested, one every 1000 ms; each frame calls
# animate(), which runs 100 solver steps.
ani = animation.FuncAnimation(fig, animate, np.arange(1, 100), init_func=init,
                              interval=1000, blit=True)
plt.show()
|
||
|
||
|
Oops, something went wrong.