diff --git a/examples/ctc/.gitignore b/examples/ctc/.gitignore new file mode 100644 index 00000000000..72b14116ab6 --- /dev/null +++ b/examples/ctc/.gitignore @@ -0,0 +1,2 @@ +# ignore generated data +*.h5 diff --git a/examples/ctc/README.md b/examples/ctc/README.md new file mode 100644 index 00000000000..19ec09def2f --- /dev/null +++ b/examples/ctc/README.md @@ -0,0 +1,19 @@ +# CTC Examples + +## Dummy Example + +The 'dummy' example shows the basic usage of the CTC Layer combined with LSTM networks. Use the `make_dummy_data.py` script to create a single batch of data that will be stored in an HDF5 file. To run an example, pass the corresponding solver prototxt to `caffe train`. + +The solver will overfit the network on this single batch of data, so the loss should decrease steadily (more or less quickly). + +### LSTM + +The `dummy_lstm_solver.prototxt` implements a standard unidirectional LSTM model. + +### BLSTM + +The `dummy_blstm_solver.prototxt` implements a bidirectional LSTM model using the ReverseLayer. + +### Plotting the learning progress + +The `plot_dummy_progress.py` script shows the learning progress, i.e. the input probabilities and their diffs for each label over time. diff --git a/examples/ctc/dummy/dummy_blstm_net.prototxt b/examples/ctc/dummy/dummy_blstm_net.prototxt new file mode 100644 index 00000000000..4e834f3878f --- /dev/null +++ b/examples/ctc/dummy/dummy_blstm_net.prototxt @@ -0,0 +1,113 @@ + +layer { + name: "data" + type: "HDF5Data" + top: "data" + top: "seq_ind" + top: "labels" + hdf5_data_param { + source: "dummy_data.txt" + batch_size: 40 + } +} + +layer { + name: "lstm1" + type: "LSTM" + bottom: "data" + bottom: "seq_ind" + top: "lstm1" + recurrent_param { + # num + 1 for blank label! (last one) + num_output: 40 + weight_filler { + type: "gaussian" + std: 0.1 + } + bias_filler { + type: "constant" + } + } +} + +layer { + name: "ip1" + type: "InnerProduct" + bottom: "lstm1" + top: "ip1" + inner_product_param { + num_output: 6 + weight_filler { + type: "gaussian" + std: 0.1 + } + axis: 2 + } +} + +layer { + name: "rev1" + type: "Reverse" + bottom: "data" + top: "rev_data" +} + +layer { + name: "lstm2" + type: "LSTM" + bottom: "rev_data" + bottom: "seq_ind" + top: "lstm2" + recurrent_param { + # num + 1 for blank label! 
(last one) + num_output: 40 + weight_filler { + type: "gaussian" + std: 0.1 + } + bias_filler { + type: "constant" + } + } +} + + +layer { + name: "ip2" + type: "InnerProduct" + bottom: "lstm2" + top: "ip2" + inner_product_param { + num_output: 6 + weight_filler { + type: "gaussian" + std: 0.1 + } + axis: 2 + } +} + +layer { + name: "rev2" + type: "Reverse" + bottom: "ip2" + top: "rev2" +} + +layer { + name: "eltwise-sum" + type: "Eltwise" + bottom: "ip1" + bottom: "rev2" + eltwise_param { operation: SUM } + top: "sum" +} + +layer { + name: "loss" + type: "CTCLoss" + bottom: "sum" + bottom: "seq_ind" + bottom: "labels" + top: "ctc_loss" +} diff --git a/examples/ctc/dummy/dummy_blstm_solver.prototxt b/examples/ctc/dummy/dummy_blstm_solver.prototxt new file mode 100644 index 00000000000..ddf7c15623c --- /dev/null +++ b/examples/ctc/dummy/dummy_blstm_solver.prototxt @@ -0,0 +1,11 @@ +net: "dummy_blstm_net.prototxt" +base_lr: 0.01 +lr_policy: "fixed" +display: 100 +max_iter: 100000 +solver_mode: CPU +average_loss: 1 +solver_type: SGD +random_seed: 9602 +clip_gradients: 10 +debug_info: false diff --git a/examples/ctc/dummy/dummy_data.txt b/examples/ctc/dummy/dummy_data.txt new file mode 100644 index 00000000000..32a11d6e667 --- /dev/null +++ b/examples/ctc/dummy/dummy_data.txt @@ -0,0 +1 @@ +dummy_data.h5 diff --git a/examples/ctc/dummy/dummy_lstm_net.prototxt b/examples/ctc/dummy/dummy_lstm_net.prototxt new file mode 100644 index 00000000000..370382901f5 --- /dev/null +++ b/examples/ctc/dummy/dummy_lstm_net.prototxt @@ -0,0 +1,55 @@ + +layer { + name: "data" + type: "HDF5Data" + top: "data" + top: "seq_ind" + top: "labels" + hdf5_data_param { + source: "dummy_data.txt" + batch_size: 40 + } +} + +layer { + name: "lstm1" + type: "LSTM" + bottom: "data" + bottom: "seq_ind" + top: "lstm1" + recurrent_param { + # num + 1 for blank label! (last one) + num_output: 40 + weight_filler { + type: "gaussian" + std: 0.1 + } + bias_filler { + type: "constant" + } + } +} + +layer { + name: "ip1" + type: "InnerProduct" + bottom: "lstm1" + top: "ip1" + inner_product_param { + num_output: 6 + weight_filler { + type: "gaussian" + std: 0.1 + } + axis: 2 + } +} + +layer { + name: "loss" + type: "CTCLoss" + bottom: "ip1" + bottom: "seq_ind" + bottom: "labels" + top: "ctc_loss" +} diff --git a/examples/ctc/dummy/dummy_lstm_solver.prototxt b/examples/ctc/dummy/dummy_lstm_solver.prototxt new file mode 100644 index 00000000000..c8a02bd2f04 --- /dev/null +++ b/examples/ctc/dummy/dummy_lstm_solver.prototxt @@ -0,0 +1,11 @@ +net: "dummy_lstm_net.prototxt" +base_lr: 0.01 +lr_policy: "fixed" +display: 100 +max_iter: 100000 +solver_mode: CPU +average_loss: 1 +solver_type: SGD +random_seed: 9602 +clip_gradients: 10 +debug_info: false diff --git a/examples/ctc/dummy/make_dummy_data.py b/examples/ctc/dummy/make_dummy_data.py new file mode 100644 index 00000000000..2511c5a0e15 --- /dev/null +++ b/examples/ctc/dummy/make_dummy_data.py @@ -0,0 +1,74 @@ +import numpy as np +import h5py + +def store_hdf5(filename, mapping): + """Function to store data mapping to a hdf5 file + + Args: + filename (str): The output filename + mapping (dic): A dictionary containing mapping from name to numpy data + The complete mapping will be stored as single datasets + in the h5py file. 
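+
+        For instance (a minimal usage sketch mirroring the __main__ block at the
+        bottom of this script):
+
+            data, seq_ind, labels = generate_data(40, 20, 5)
+            store_hdf5("dummy_data.h5",
+                       {"data": data, "seq_ind": seq_ind, "labels": labels})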
+ """ + + print("Storing hdf5 file %s" % filename) + with h5py.File(filename, 'w') as hf: + for label, data in mapping.items(): + print(" adding dataset %s with shape %s" % (label, data.shape)) + hf.create_dataset(label, data=data) + + print(" finished") + +def generate_data(T_, C_, lab_len_): + """Function to generate dummy data + + The data is generated non randomly by a defined function. + The sequence length is exactly T_. + The target label sequence will be [0 1 2 ... (lab_len_-1)]. + + Args: + T_ (int): The number of timesteps (this value must match the batch_size of the caffe net) + C_ (int): The number of channgels/labels + lab_len_(int): The label size that must be smaller or equals T_. This value + will also be used as the maximum allowed label. The label size in the network + must therefore be 6 = 5 + 1 (+1 for blank label) + + Returns: + data (numpy array): A numpy array of shape (T_, 1, C_) containing dummy data + sequence_indicators (numpy array): A numpy array of shape (T_, 1) indicating the + sequence + labels (numpy array): A numpy array of shape (T_, 1) defining the label sequence. + labels will be -1 for all elements greater than T_ (indicating end of sequence). + """ + assert(lab_len_ <= T_) + + data = np.zeros((T_, 1, C_), dtype=np.float32) + + # this is an arbitrary function to generate data not randomly + for t in range(T_): + for c in range(C_): + data[t,0,c] = ((c * 0.1 / C_)**2 - 0.25 + t * 0.2 / T_) * (int(T_ / 5)) / T_ + + # The sequence length is exactly T_. + sequence_indicators = np.full((T_, 1), 1, dtype=np.float32) + sequence_indicators[0] = 0 + + # The label lengh is lab_len_ + # The output sequence is [0 1 2 ... lab_len_-1] + labels = np.full((T_, 1), -1, dtype=np.float32) + labels[0:lab_len_, 0] = range(lab_len_) + + return data, sequence_indicators, labels + + +if __name__=="__main__": + # generate the dummy data + # not that T_ = 40 must match the batch_size of 40 in the network setup + # as required by the CTC alorithm to see the full sequence + # The label length and max label is set to 5. 
Use 6 = 5 + 1 for the label size in the network + # to add the blank label + data, sequence_indicators, labels = generate_data(40, 20, 5) + + # and write it to the h5 file + store_hdf5("dummy_data.h5", {"data" : data, "seq_ind" : sequence_indicators, "labels" : labels}) + diff --git a/examples/ctc/dummy/plot_dummy_progress.py b/examples/ctc/dummy/plot_dummy_progress.py new file mode 100644 index 00000000000..2e6f76c5e0a --- /dev/null +++ b/examples/ctc/dummy/plot_dummy_progress.py @@ -0,0 +1,53 @@ +import caffe +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as animation + +solver = caffe.SGDSolver("dummy_blstm_solver.prototxt") + +data = solver.net.blobs['sum'].data +shape = data.shape + +t = range(shape[0]) + +fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True) + +ax1.set_ylabel('label prediction probabilities') +ax2.set_ylabel('gradients') +ax1.set_xlabel('Time') +ax2.set_xlabel('Time') +data_p = [ax1.plot(t, np.sin(t))[0] for _ in range(shape[2])] +diff_p = [ax2.plot(t, np.cos(t))[0] for _ in range(shape[2])] + +def init(): + for i in range(shape[2]): + data_p[i].set_ydata(np.ma.array(np.zeros(shape[0]), mask=True)) + diff_p[i].set_ydata(np.ma.array(np.zeros(shape[0]), mask=True)) + + return data_p + diff_p + +def update_plot(data, diff): + for i in range(shape[2]): + data_p[i].set_ydata(data[:,0,i]) + diff_p[i].set_ydata(diff[:,0,i]) + + return data_p + diff_p + + +def animate(i): + solver.step(100) + data = solver.net.blobs['sum'].data + diff = solver.net.blobs['sum'].diff + ax1.relim() + ax1.autoscale_view(True, True, True) + ax2.relim() + ax2.autoscale_view(True, True, True) + return update_plot(data, diff) + + +ani = animation.FuncAnimation(fig, animate, np.arange(1, 100), init_func=init, + interval=1000, blit=True) +plt.show() + + + diff --git a/include/caffe/layers/ctc_decoder_layer.hpp b/include/caffe/layers/ctc_decoder_layer.hpp new file mode 100644 index 00000000000..368843346bd --- /dev/null +++ b/include/caffe/layers/ctc_decoder_layer.hpp @@ -0,0 +1,226 @@ +#ifndef CAFFE_CTC_DECODER_LAYER_HPP_ +#define CAFFE_CTC_DECODER_LAYER_HPP_ + +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief A layer that implements the decoder for a ctc + * + * Bottom blob is the probability of label and the sequence indicators. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. 
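+ *
+ * As a rough sketch of what the greedy (best path) decoder below does: at
+ * every timestep the label with the highest probability is taken, repeated
+ * labels are merged (if ctc_merge_repeated is set) and the blank label is
+ * dropped, e.g. with blank index 4: [0 0 4 1 1 4 4 2] -> [0 1 2].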
+ */ +template +class CTCDecoderLayer : public Layer { + public: + typedef vector Sequence; + typedef vector Sequences; + + public: + explicit CTCDecoderLayer(const LayerParameter& param) + : Layer(param) + , blank_index_(param.ctc_decoder_param().blank_index()) + , merge_repeated_(param.ctc_decoder_param().ctc_merge_repeated()) { + } + + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top) { + // required additional output blob for accuracies + if (bottom.size() == 3) {CHECK_EQ(top.size(), 2);} + } + + virtual void Reshape(const vector*>& bottom, + const vector*>& top) { + Blob* scores = top[0]; + + const Blob* probabilities = bottom[0]; + T_ = probabilities->shape(0); + N_ = probabilities->shape(1); + C_ = probabilities->shape(2); + + output_sequences_.clear(); + output_sequences_.resize(N_); + + scores->Reshape(N_, 1, 1, 1); + + if (blank_index_ < 0) { + blank_index_ = C_ - 1; + } + + if (top.size() == 2) { + // Single accuracy as output + top[1]->Reshape(1, 1, 1, 1); + } + } + + virtual inline const char* type() const { return "CTCDecoder"; } + + // probabilities (T x N x C), + // sequence_indicators (T x N), + // target_sequences (T X N) [optional] + // if a target_sequence is provided, an additional accuracy top blob is + // required + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 3; } + + // output scores, accuracy [optional, if target_sequences as bottom blob] + virtual inline int MinTopBlobs() const { return 1; } + virtual inline int MaxTopBlobs() const { return 2; } + + const Sequences& OutputSequences() const {return output_sequences_;} + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Blob* probabilities = bottom[0]; + const Blob* sequence_indicators = bottom[1]; + Blob* scores = top[0]; + + Decode(probabilities, sequence_indicators, &output_sequences_, scores); + + if (top.size() == 2) { + // compute accuracy + Dtype &acc = top[1]->mutable_cpu_data()[0]; + acc = 0; + + CHECK_GE(bottom.size(), 3); // required target sequences blob + const Blob* target_sequences_data = bottom[2]; + const Dtype* ts_data = target_sequences_data->cpu_data(); + for (int n = 0; n < N_; ++n) { + Sequence target_sequence; + for (int t = 0; t < T_; ++t) { + const Dtype dtarget = ts_data[target_sequences_data->offset(t, n)]; + if (dtarget < 0) { + // sequence has finished + break; + } + // round to int, just to be sure + const int target = static_cast(0.5 + dtarget); + target_sequence.push_back(target); + } + + const int ed = EditDistance(target_sequence, output_sequences_[n]); + + acc += ed * 1.0 / + std::max(target_sequence.size(), output_sequences_[n].size()); + } + + acc = 1 - acc / N_; + CHECK_GE(acc, 0); + CHECK_LE(acc, 1); + } + } + + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + for (int i = 0; i < propagate_down.size(); ++i) { + if (propagate_down[i]) { NOT_IMPLEMENTED; } + } + } + + virtual void Decode(const Blob* probabilities, + const Blob* sequence_indicators, + Sequences* output_sequences, + Blob* scores) const = 0; + + int EditDistance(const Sequence &s1, const Sequence &s2) { + const size_t len1 = s1.size(); + const size_t len2 = s2.size(); + + Sequences d(len1 + 1, Sequence(len2 + 1)); + + d[0][0] = 0; + for (size_t i = 1; i <= len1; ++i) {d[i][0] = i;} + for (size_t i = 1; i <= len2; ++i) {d[0][i] = i;} + + for (size_t i = 1; i <= len1; ++i) { + for (size_t j = 1; j <= len2; ++j) { + d[i][j] = 
std::min( + std::min( + d[i - 1][j] + 1, + d[i][j - 1] + 1), + d[i - 1][j - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1)); + } + } + + return d[len1][len2]; + } + + protected: + Sequences output_sequences_; + int T_; + int N_; + int C_; + int blank_index_; + bool merge_repeated_; +}; + +template +class CTCGreedyDecoderLayer : public CTCDecoderLayer { + private: + using typename CTCDecoderLayer::Sequences; + using CTCDecoderLayer::T_; + using CTCDecoderLayer::N_; + using CTCDecoderLayer::C_; + using CTCDecoderLayer::blank_index_; + using CTCDecoderLayer::merge_repeated_; + + public: + explicit CTCGreedyDecoderLayer(const LayerParameter& param) + : CTCDecoderLayer(param) {} + + virtual inline const char* type() const { return "CTCGreedyDecoder"; } + + protected: + virtual void Decode(const Blob* probabilities, + const Blob* sequence_indicators, + Sequences* output_sequences, + Blob* scores) const { + CHECK_EQ(CHECK_NOTNULL(scores)->count(), N_); + Dtype* score_data = scores->mutable_cpu_data(); + for (int n = 0; n < N_; ++n) { + int prev_class_idx = -1; + score_data[n] = 0; + + for (int t = 0; /* check at end */; ++t) { + // get maximum probability and its index + int max_class_idx = 0; + const Dtype* probs = probabilities->cpu_data() + + probabilities->offset(t, n); + Dtype max_prob = probs[0]; + ++probs; + for (int c = 1; c < C_; ++c, ++probs) { + if (*probs > max_prob) { + max_class_idx = c; + max_prob = *probs; + } + } + + score_data[n] += -max_prob; + + if (max_class_idx != blank_index_ + && !(merge_repeated_&& max_class_idx == prev_class_idx)) { + output_sequences->at(n).push_back(max_class_idx); + } + + prev_class_idx = max_class_idx; + + if (t + 1 == T_ || sequence_indicators->data_at(t + 1, n, 0, 0) == 0) { + // End of sequence + break; + } + } + } + } +}; + +} // namespace caffe + +#endif // CAFFE_CTC_DECODER_LAYER_HPP_ diff --git a/include/caffe/layers/ctc_loss_layer.hpp b/include/caffe/layers/ctc_loss_layer.hpp new file mode 100644 index 00000000000..9110659f368 --- /dev/null +++ b/include/caffe/layers/ctc_loss_layer.hpp @@ -0,0 +1,270 @@ +#ifndef CAFFE_CTC_LOSS_LAYER_HPP +#define CAFFE_CTC_LOSS_LAYER_HPP + +#include +#include + +#include "caffe/layers/loss_layer.hpp" + +namespace caffe { +template + + +/** + * @brief Implementation of the CTC (Connectionist Temporal Classification) algorithm + * to label unsegmented sequence data with recurrent neural networks + * + * The input data is expected to follow the rules for the recurrent layers, meaning: + * T x N x L, where T is the time compontent, N the number of simulaneously computed + * input sequences and L is the size of the possible labels. There will be a softmax + * applied to the input data. No need to add a manual softmax layer. Note that L + * must be the size of your actual label count plus one. This last entry represents + * the required 'blank_index' for the algorith. + * + * The second input blob are the sequence indicators for the data with shape T x N. + * A 0 means the start of a sequence. See RecurrentLayer for additional information. + * + * The third input is the blob of the target sequence with shape T X N. The data + * is expected to contain the labeling of the target sequence and -1 if the sequence + * has ended. + * + * Sample input data for T = 10, N = 1 (this column is dropped in the data), C = 5 + * (the data is filled with dummy values), and the target sequence [012]. + * The input sequence has a length of 8. The number of labels is 4 (3 + 1). 
+ * + * T | data | seq_ind | target_seq | + * -- | ----------- | ------- | ---------- | + * 0 | [1 5 2 2 5] | 0 | 0 | + * 1 | [3 3 2 3 4] | 1 | 2 | + * 2 | [7 3 3 5 4] | 1 | 1 | + * 3 | [0 5 3 2 5] | 1 | -1 | + * 4 | [0 4 1 2 4] | 1 | -1 | + * 5 | [2 4 3 5 7] | 1 | -1 | + * 6 | [3 4 1 3 4] | 1 | -1 | + * 7 | [8 4 2 4 4] | 1 | -1 | + * 8 | [0 0 0 0 0] | 0 | -1 | + * 9 | [0 0 0 0 0] | 0 | -1 | + * + * Note that the complete sequence must fit into a (time) batch and each sequence + * must start at t = 0 of that batch. + * + * To split the computation into Forward and Backward passes the intermediate results + * (alpha, beta, l_prime, log_p_z_x) are stored during the forward pass and are + * reused during the backward pass. + */ +class CTCLossLayer : public LossLayer { + public: + // double blob for storing probabilities with higher accuracy + typedef Blob ProbBlob; + + // alpha or beta variables are a probability blob + typedef ProbBlob CTCVariables; + + // blob for storing lengths (sequences) + typedef Blob LengthBlob; + + // blob for storing a sequence + typedef Blob SequenceBlob; + + // Vector storing the label sequences for each sequence + typedef vector LabelSequences; + + public: + explicit CTCLossLayer(const LayerParameter& param); + virtual ~CTCLossLayer(); + + virtual void LayerSetUp( + const vector*>& bottom, const vector*>& top); + virtual void Reshape( + const vector*>& bottom, const vector*>& top); + + virtual inline const char* type() const { return "CTCLoss"; } + + // probabilities, sequence indicators, target sequence + virtual inline int ExactNumBottomBlobs() const { return 3; } + + // loss + virtual inline int ExactNumTopBlobs() const { return 1; } + + // access to internal calculation variables, + // used for testing intermediate states + void SetLossCalculationT(int t) {loss_calculation_t_ = t;} + const LengthBlob& SequenceLength() const {return seq_len_;} + const LengthBlob& LabelLength() const {return label_len_;} + const ProbBlob& LogPzx() const {return log_p_z_x_;} + const vector& LogAlpha() const {return log_alpha_;} + const vector& LogBeta() const {return log_beta_;} + const LabelSequences& LPrimes() const {return l_primes_;} + const vector*>& Y() const {return y_;} + + protected: + /** + * @brief Computes the loss and the error gradients for the input data + * in one step (due to optimization isses) + * + * @param bottom input Blob vector (length 3) + * -# @f$ (T \times N \times C) @f$ + * the inputs @f$ x @f$ + * -# @f$ (T \times N) @f$ + * the sequence indicators for the data + * (must be 0 at @f$ t = 0 @f$ and 1 during a sequence) + * -# @f$ (T \times N) @f$ + * the target sequence + * (must start at @f$ t = 0 @f$ and filled with -1 if the sequence has ended) + * @param top output Blob vector (length 1) + * -# @f$ (1) @f$ + * the computed loss + */ + + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + /** + * @brief Unused. 
Gradient calculation is done in Forward_cpu + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom); + + private: + /** + * @brief Calculates the loss of the given data batch + * @param seq_len_blob The blob containing the data sequence lengths + * @param target_seq_blob The blob containing the target sequence + * @param target_seq_len_blob The blob containing the target sequence lengths + * @param data_blob The blob containing the probability distribution + * @param preprocess_collapse_repeated See preprocess_collapse_repeated_ + * @param ctc_merge_repeated See ctc_merge_repeated_ + * @param loss The output of the loss + * @param requires_backprob Calculate the gradients + */ + void CalculateLoss(const LengthBlob *seq_len_blob, + const Blob *target_seq_blob, + const LengthBlob *target_seq_len_blob, + Blob *data_blob, + ProbBlob *log_p_z_x, + LabelSequences *l_primes, + vector*> *y, + bool preprocess_collapse_repeated, + bool ctc_merge_repeated, + Dtype *loss, + bool requires_backprob) const; + + /** + * @brief Calculates the forward variables of the CTC algorithm denoted by + * @f$ \alpha @f$ + * @param l_prime The target sequence with inserted blanks + * @param y The input probabilities for this sequence + * @param ctc_merge_repeated See ctc_merge_repeated_ + * @param log_alpha The output blob that stores the variables + */ + void CalculateForwardVariables(const SequenceBlob* l_prime, + const Blob* y, + bool ctc_merge_repeated, + CTCVariables* log_alpha) const; + + /** + * @brief Calculates the backward variables of the CTC algorithm denoted by + * @f$ \beta @f$ + * @param l_prime The target sequence with inserted blanks + * @param y The input probabilities for this sequence + * @param ctc_merge_repeated See ctc_merge_repeated_ + * @param log_beta The output blob that stores the variables + */ + void CalculateBackwardVariables(const SequenceBlob* l_prime, + const Blob* y, + bool ctc_merge_repeated, + CTCVariables* log_beta) const; + + /** + * @brief Calculate the gradient of the input variables + * @param b The number of the sequence in the input data + * @param seq_length The sequence length + * @param l_prime The target sequence with inserted blanks + * @param y_d_blob The softmax input variables for this specific sequence + * @param log_alpha The log values of the forward variables + * @param log_beta The log values of the backward variables + * @param log_p_z_x The computed probability of the path (corresponds to loss) + * @param y The input blob for this sequence. Here only the diff data will be used + * as output for the gradient. 
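+ *
+ * As a sketch of the computation (following Graves): for label k at time t
+ * the gradient w.r.t. the unnormalized input activation is
+ *   y_t(k) - exp(log(sum_{u : l'_u = k} alpha_t(u) * beta_t(u)) - log p(z|x)),
+ * i.e. the prediction minus the normalized path probability mass passing
+ * through label k at time t.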
+ */ + void CalculateGradient(int b, int seq_length, const SequenceBlob* l_prime, + const Blob* y_d_blob, + const CTCVariables* log_alpha, + const CTCVariables* log_beta, + double log_p_z_x, + Blob* y) const; + + /** + * @brief Computes the L' sequence of the target sequences + * @param preprocess_collapse_repeated See proprocess_collapse_repeated_ + * @param N The number of parallel sequences + * @param num_classes The number of allowed labels + * @param seq_len The sequence lengths (lengths of the input data) + * @param labels The labels blob + * @param label_len The length of the labels for each sequence + * @param max_u_prime Output of the maximum length of a target sequence + * @param l_primes Output of the label sequences + */ + void PopulateLPrimes(bool preprocess_collapse_repeated, + int N, + int num_classes, + const LengthBlob& seq_len, + const Blob& labels, + const LengthBlob& label_len, + int *max_u_prime, + LabelSequences* l_primes) const; + /** + * @brief Transform a sequence to the sequence with inserted blanks. + * @param l the default sequence + * @param l_prime the sequence with inserted blanks + * + * The length of the output will be |L'| = 2 |L| + 1. + * + * E.g. [0 1 4] -> [5 0 5 1 5 4 5] where 5 indicates the blank label. + * The number of classes is therefore 6 + */ + void GetLPrimeIndices(const vector& l, SequenceBlob* l_prime) const; + + int T_; + int N_; + int C_; + + int output_delay_; + int blank_index_; + + bool preprocess_collapse_repeated_; + bool ctc_merge_repeated_; + + /// the time for which to calculate the loss + /// see Graves Eq. (7.27) + /// Note that the result must be the same for each 0 <= t < T + /// Therefore you can chose an arbitrary value, default 0 + int loss_calculation_t_; + + // Intermediate variables that are calculated during the forward pass + // and reused during the backward pass + + // blob to store the sequence lengths (input data) + LengthBlob seq_len_; + + // blob to store the label lengths (target label sequence) + LengthBlob label_len_; + + // blob to store log(p(z|x)) for each batch + ProbBlob log_p_z_x_; + + // blobs to store the alpha and beta variables of each input sequence + // the algorithm will store the logarithm of these variables + vector log_alpha_; + vector log_beta_; + + // blobs to store the l_primes of the sequences + LabelSequences l_primes_; + + // blobs to store the intermediate softmax outputs + vector*> y_; +}; + +} // namespace caffe + +#endif // CAFFE_CTC_LOSS_LAYER_HPP diff --git a/include/caffe/layers/reverse_layer.hpp b/include/caffe/layers/reverse_layer.hpp new file mode 100644 index 00000000000..016bd9567a5 --- /dev/null +++ b/include/caffe/layers/reverse_layer.hpp @@ -0,0 +1,47 @@ +#ifndef REVERSE_LAYER_HPP +#define REVERSE_LAYER_HPP + +#include + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/layers/neuron_layer.hpp" + +namespace caffe { + +/* + * @brief Reverses the data of the input Blob into the output blob. + * + * Note: This is a useful layer if you want to reverse the time of + * a recurrent layer. 
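+ *
+ * For example, with the default axis 0 and a T x N x ... shaped blob, the
+ * output slice at index t is the input slice at index T - 1 - t, while the
+ * content of each slice itself is left untouched.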
+ */ + +template +class ReverseLayer : public NeuronLayer { + public: + explicit ReverseLayer(const LayerParameter& param); + + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "Reverse"; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int axis_; +}; + +} // namespace caffe + +#endif // REVERSE_LAYER_HPP diff --git a/src/caffe/layers/ctc_decoder_layer.cpp b/src/caffe/layers/ctc_decoder_layer.cpp new file mode 100644 index 00000000000..174b0bf30d3 --- /dev/null +++ b/src/caffe/layers/ctc_decoder_layer.cpp @@ -0,0 +1,10 @@ +#include "caffe/layers/ctc_decoder_layer.hpp" + +namespace caffe { + +INSTANTIATE_CLASS(CTCDecoderLayer); + +INSTANTIATE_CLASS(CTCGreedyDecoderLayer); +REGISTER_LAYER_CLASS(CTCGreedyDecoder); + +} // namespace caffe diff --git a/src/caffe/layers/ctc_loss_layer.cpp b/src/caffe/layers/ctc_loss_layer.cpp new file mode 100644 index 00000000000..89d4a7e9552 --- /dev/null +++ b/src/caffe/layers/ctc_loss_layer.cpp @@ -0,0 +1,692 @@ +#include "caffe/layers/ctc_loss_layer.hpp" + +#include +#include +#include +#include +#include + +namespace caffe { + +/** + * @brief Implodes a vector into a string + * @param in The input vector + * @param delim The delimiter + * @return string out of the input vector + */ +std::string imploded(const std::vector &in, + const std::string &delim = " ") { + std::stringstream out; + out << in[0]; + for (int i = 1; i < in.size(); ++i) { + out << delim << in[i]; + } + return out.str(); +} + +/** + * @brief Converts a Dtype memory label sequence into a int vector + * @param label Pointer to the start of the sequence + * @param label_size The size of the label (number of elements to read from + * label) + * @param label_incr The offset to the next label in the sequence in the raw + * data input of label + * @return int vector containing the sequence + */ +template +vector extract_label_sequence(const Dtype* label, + int label_size, + int label_incr) { + vector out(label_size); + for (int i = 0; i < label_size; ++i) { + out[i] = static_cast(*label + 0.5); + label += label_incr; + } + return out; +} + +// Probability calculation utils. +// Note that only in double space. +// The c++ standard before c++11 does not support templates on variables yet. +// When setting c++11 to required standard add template and replace +// double. + +/// Zero probability in log space +static const double kLogZero = -std::numeric_limits::infinity(); + +/** + * @brief Adds two log probabilities. This equates a multiplication of + * probabilities in normal space. + * @param log_prob_1 The first probability + * @param log_prob_2 The second probability + * @returns The added log prob + */ +inline double LogSumExp(double log_prob_1, double log_prob_2) { + // Always have 'b' be the smaller number to avoid the exponential from + // blowing up. + if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) { + return kLogZero; + } else { + return (log_prob_1 > log_prob_2) + ? 
log_prob_1 + log1p(exp(log_prob_2 - log_prob_1)) + : log_prob_2 + log1p(exp(log_prob_1 - log_prob_2)); + } +} + + +template +CTCLossLayer::CTCLossLayer(const LayerParameter& param) + : LossLayer(param) { + output_delay_ = param.ctc_loss_param().output_delay(); + blank_index_ = param.ctc_loss_param().blank_index(); + preprocess_collapse_repeated_ = + param.ctc_loss_param().preprocess_collapse_repeated(); + ctc_merge_repeated_ = param.ctc_loss_param().ctc_merge_repeated(); + loss_calculation_t_ = param.ctc_loss_param().loss_calculation_t(); +} + +template +CTCLossLayer::~CTCLossLayer() { + // clear alpha and beta blobs memory + for (int n = 0; n < N_; ++n) { + // dummy shapes + delete log_alpha_[n]; + delete log_beta_[n]; + delete l_primes_[n]; + delete y_[n]; + } + log_alpha_.clear(); + log_beta_.clear(); + l_primes_.clear(); + y_.clear(); +} + +template +void CTCLossLayer::LayerSetUp( + const vector*>& bottom, + const vector*>& top) { + LossLayer::LayerSetUp(bottom, top); + + loss_calculation_t_ = 0; + + const Blob* probs = bottom[0]; + const Blob* seq_ind = bottom[1]; + const Blob* label_seq = bottom[2]; + + + T_ = probs->num(); + N_ = probs->channels(); + C_ = probs->height(); + CHECK_EQ(probs->width(), 1); + + CHECK_EQ(T_, seq_ind->num()); + CHECK_EQ(N_, seq_ind->channels()); + CHECK_EQ(N_, label_seq->channels()); + + if (blank_index_ < 0) { + // select the last label as default label if user did not specify + // one with the blank_index parameter. + blank_index_ = C_ - 1; + } + + // resize data storage blobs for each sequence + seq_len_.Reshape(N_, 1, 1, 1); + label_len_.Reshape(N_, 1, 1, 1); + log_p_z_x_.Reshape(N_, 1, 1, 1); + + + // resize alpha and beta containers to the required input sequences length + log_alpha_.resize(N_); + log_beta_.resize(N_); + l_primes_.resize(N_); + y_.resize(N_); + + for (int n = 0; n < N_; ++n) { + // dummy shapes + log_alpha_[n] = new CTCVariables(1, 1, 1, 1); + log_beta_[n] = new CTCVariables(1, 1, 1, 1); + l_primes_[n] = new SequenceBlob(1, 1, 1, 1); + y_[n] = new Blob(1, 1, 1, 1); + } +} + +template +void CTCLossLayer::Reshape( + const vector*>& bottom, + const vector*>& top) { + LossLayer::Reshape(bottom, top); +} + +template +void CTCLossLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + Blob *data_blob = bottom[0]; + const Blob *seq_ind_blob = bottom[1]; + const Blob *target_seq_blob = bottom[2]; + + CHECK_EQ(data_blob->num(), seq_ind_blob->num()); + CHECK_EQ(data_blob->num(), target_seq_blob->num()); + + CHECK_EQ(data_blob->channels(), seq_ind_blob->channels()); + CHECK_EQ(data_blob->channels(), target_seq_blob->channels()); + + // compute the sequence length and label length + int* seq_len = seq_len_.mutable_cpu_data(); + int* label_len = label_len_.mutable_cpu_data(); + for (int n = 0; n < N_; ++n) { + seq_len[n] = T_; // default value is maximal allowed length + label_len[n] = T_; // default value is maximal allowed length + + const Dtype *seq = seq_ind_blob->cpu_data() + n; + const Dtype *label = target_seq_blob->cpu_data() + n; + + // sequence indicators start with seq == 0.0 to indicate the start of a + // sequence. 
Skip at t = 0, so start at t = 1 + seq += seq_ind_blob->channels(); + for (int t = 1; t < T_; ++t) { + if (static_cast(*seq + 0.5) == 0) { + seq_len[n] = t; + break; + } + seq += seq_ind_blob->channels(); + } + + // label indicators are negative if the sequence has ended + for (int t = 0; t < T_; ++t) { + if (*label < 0.0) { + label_len[n] = t; + break; + } + label += target_seq_blob->channels(); + } + + CHECK_LE(label_len[n], seq_len[n]) + << "The label length must be smaller or equals the sequence length!"; + } + + + // compute loss (in forward pass), and store computed results for backward + // pass + Dtype &loss = top[0]->mutable_cpu_data()[0]; + + CalculateLoss(&seq_len_, + target_seq_blob, + &label_len_, + data_blob, + &log_p_z_x_, + &l_primes_, + &y_, + preprocess_collapse_repeated_, + ctc_merge_repeated_, + &loss, + true); + + // normalize by number of parallel batches + loss /= N_; +} + +template +void CTCLossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK_EQ(propagate_down[0], true) + << "Required to propagate to probabilities"; + CHECK_EQ(propagate_down[1], false) + << "Cannot propagate to sequence indicators"; + CHECK_EQ(propagate_down[2], false) + << "Cannot propagate to target label sequence"; + + Blob *data_blob = bottom[0]; + + const int* seq_len = seq_len_.cpu_data(); + const double* log_p_z_x = log_p_z_x_.cpu_data(); + + // clear all diffs in data blob + caffe_set(data_blob->count(), 0, data_blob->mutable_cpu_diff()); + + // for each batch compute the gradient using the alpha and beta variables, + // y_, p_z_x and l_primes that were computed in the forward pass + for (int b = 0; b < N_; ++b) { + // We compute the derivative if needed + CalculateGradient(b, + seq_len[b], + l_primes_[b], + y_[b], + log_alpha_[b], + log_beta_[b], + log_p_z_x[b], + data_blob); + } +} + +template +void CTCLossLayer::CalculateLoss( + const LengthBlob *seq_len_blob, + const Blob* target_seq_blob, + const LengthBlob *target_seq_len_blob, + Blob* data_blob, + ProbBlob *log_p_z_x, + LabelSequences *l_primes, + vector *> *y, + bool preprocess_collapse_repeated, + bool ctc_merge_repeated, + Dtype* loss, + bool requires_backprob) const { + CHECK(seq_len_blob); + CHECK(target_seq_blob); + CHECK(target_seq_len_blob); + CHECK(data_blob); + CHECK(loss); + + const int *seq_len = seq_len_blob->cpu_data(); + + const int num_time_steps = T_; + const int batch_size = N_; + const int num_classes = C_; + + // check validity of data + CHECK_EQ(data_blob->num(), num_time_steps); + CHECK_EQ(data_blob->channels(), batch_size); + CHECK_EQ(data_blob->height(), num_classes); + CHECK_EQ(data_blob->width(), 1); + + // check validity of sequence_length arrays + int max_seq_len = seq_len[0]; + for (int b = 0; b < batch_size; ++b) { + CHECK_GE(seq_len[b], 0); + CHECK_LE(seq_len[b], num_time_steps); + max_seq_len = std::max(max_seq_len, seq_len[b]); + } + + // set loss to 0 + *loss = 0; + + // calculate the modified label sequence l' for each batch element, + // and calculate the maximum necessary allocation size. 
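+  // (l' is built by inserting the blank label before every target label and
+  //  appending one at the end, so its length is 2 * |l| + 1, see
+  //  GetLPrimeIndices)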
+ int max_u_prime = 0; + PopulateLPrimes(preprocess_collapse_repeated, + batch_size, + num_classes, + *seq_len_blob, + *target_seq_blob, + *target_seq_len_blob, + &max_u_prime, + l_primes); + + // Compute loss and gradients + // ============================================================ + + // TODO: this can be parallelized + for (int b = 0; b < batch_size; ++b) { + const int seq_len_b = seq_len[b]; + if (seq_len_b == 0) { + continue; // zero length, no gradients or loss to be computed + } + + const SequenceBlob *l_prime = l_primes->at(b); + + const int b_T = seq_len_b - output_delay_; + + // alpha and beta reshape and access + CTCVariables* log_alpha_b = log_alpha_[b]; + CTCVariables* log_beta_b = log_beta_[b]; + log_alpha_b->Reshape(l_prime->count(), b_T, 1, 1); + log_beta_b->Reshape(l_prime->count(), b_T, 1, 1); + double *log_alpha_data = log_alpha_b->mutable_cpu_data(); + double *log_beta_data = log_beta_b->mutable_cpu_data(); + + // Work matrices, pre-allocated to the size required by this batch item + const Dtype* data_start = data_blob->cpu_data(); + Blob* y_b = y->at(b); + y_b->Reshape(seq_len_b, C_, 1, 1); + Dtype* y_start = y_b->mutable_cpu_data(); + + // compute softmax (until sequence length is sufficient) + for (int t = 0; t < seq_len_b; ++t) { + const Dtype* data = data_start + data_blob->offset(t, b); + Dtype* y_out_start = y_start + y_b->offset(t); + Dtype max_coeff = *data; + // get max coeff + for (const Dtype* c_data = data + 1; c_data != data + C_; ++c_data) { + max_coeff = std::max(max_coeff, *c_data); + } + // calc exp and its sum + Dtype sum = 0; + Dtype* y_out = y_out_start; + for (const Dtype* c_data = data; c_data != data + C_; ++c_data) { + *y_out = exp(*c_data - max_coeff); + sum += *y_out++; + } + // division by sum + for (y_out = y_out_start; y_out != y_out_start + C_; ++y_out) { + *y_out /= sum; + } + } + + // Compute forward, backward variables. + CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, log_alpha_b); + CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, log_beta_b); + + // the lost is computed as the log(p(z|x)) between the target and the + // prediction. Do lazy evaluation of log_prob here. + double& log_p_z_x_b = log_p_z_x->mutable_cpu_data()[b]; + log_p_z_x_b = kLogZero; + const int loss_calc_t = std::max(0, std::min(loss_calculation_t_, b_T - 1)); + for (int u = 0; u < l_prime->count(); ++u) { + int offset = log_alpha_b->offset(u, loss_calc_t); + log_p_z_x_b = LogSumExp(log_p_z_x_b, + log_alpha_data[offset] + log_beta_data[offset]); + } + + // use negative loss for display + *loss += -log_p_z_x_b; + } +} + +template +void CTCLossLayer::PopulateLPrimes(bool preprocess_collapse_repeated, + int N, + int num_classes, + const LengthBlob &seq_len, + const Blob& labels, + const LengthBlob &label_len, + int *max_u_prime, + LabelSequences* l_primes) const { + CHECK_EQ(seq_len.num(), N); + CHECK_EQ(seq_len.count(), seq_len.num()); // shape must be N x 1 x 1 x 1 + CHECK_EQ(labels.channels(), N); + CHECK_EQ(label_len.num(), N); + CHECK_EQ(label_len.count(), label_len.num()); // shape must be N x 1 x 1 x 1 + CHECK(max_u_prime); + CHECK(l_primes); + + *max_u_prime = 0; // keep track of longest l' modified label sequence. 
+ + const int* lab_len_d = label_len.cpu_data(); + const int* seq_len_d = seq_len.cpu_data(); + + for (int n = 0; n < N; ++n) { + // Assume label is in Label proto + const int label_size = lab_len_d[n]; + // pointer to the first element of the sequence + const Dtype* label = labels.cpu_data() + n; + // increment for getting label at next t + const int label_incr = labels.channels(); + CHECK_GT(label_size, 0) + << "Labels length is zero in sequence number " << n; + + const int seq_size = seq_len_d[n]; // round Dtype to int for sequence size + + // DLOG(INFO) << "label for sequence number " << n << ": " + // << imploded(extract_label_sequence(label, label_size, label_incr)); + + // target indices + std::vector l; + + bool finished_sequence = false; + const Dtype* prev_label = 0; + for (int i = 0; i < label_size; ++i) { + if (i == 0 || !preprocess_collapse_repeated || *label != *prev_label) { + int i_label = static_cast(*label + 0.5); // integer label (round) + if (i_label >= num_classes - 1) { + finished_sequence = true; + } else { + if (finished_sequence) { + // saw an invalid sequence with non-null following null labels. + LOG(FATAL) << "Saw a non-null label (index >= num_classes - 1) " + << "following a null label, sequence number " << n + << ", num_classes " << num_classes << ", labels [" + << imploded(l) << "]"; + } + l.push_back(i_label); + } + } + prev_label = label; + label += label_incr; + } + + // make sure there is enough time to output the target indices. + int time = seq_size - output_delay_; + int required_time = label_size; + for (int i = 0; i < l.size(); ++i) { + int l_i = l[i]; + CHECK_GE(l_i, 0) << "All labels must be nonnegative integers. " + << "Sequcene number " << n << ", labels " + << imploded(l); + CHECK_LT(l_i, num_classes) + << "No label may be greater than num_classes: " << num_classes + << ". At sequence number " << n << ", labels [" << imploded(l) + << "]"; + } + + CHECK_LE(required_time, time) + << "Not enough time for target transition sequence"; + + // Target indices with blanks before each index and a blank at the end. + // Length U' = 2U + 1. + // convert l to l_prime + GetLPrimeIndices(l, l_primes->at(n)); + *max_u_prime = std::max(*max_u_prime, l_primes->at(n)->count()); + } +} + +template +void CTCLossLayer::GetLPrimeIndices(const std::vector& l, + SequenceBlob *l_prime) const { + l_prime->Reshape(2 * l.size() + 1, 1, 1, 1); + int* l_prime_d = l_prime->mutable_cpu_data(); + + for (int i = 0; i < l.size(); ++i) { + int label = l[i]; + *l_prime_d++ = blank_index_; + *l_prime_d++ = label; + } + + *l_prime_d = blank_index_; +} + +template +void CTCLossLayer::CalculateForwardVariables( + const SequenceBlob* l_prime, + const Blob* y, + bool ctc_merge_repeated, + CTCVariables* log_alpha) const { + // Note that the order of log beta is N x T instead of T x N + const int U = l_prime->count(); + const int T = log_alpha->channels(); + CHECK_EQ(U, log_alpha->num()); + + // Data pointers, fill alpha with kLogZero + double* log_alpha_d = log_alpha->mutable_cpu_data(); + caffe_set(log_alpha->count(), kLogZero, log_alpha->mutable_cpu_data()); + const Dtype* y_d = y->cpu_data(); + const int* l_prime_d = l_prime->cpu_data(); + + // Initialize alpha values in Graves Eq (7.5) and Eq (7.6). + log_alpha_d[log_alpha->offset(0, 0)] + = log(y_d[y->offset(output_delay_, blank_index_)]); + // Below, l_prime[1] == label[0] + const int label_0 = (U > 1) ? 
l_prime_d[1] : blank_index_; + log_alpha_d[log_alpha->offset(1, 0)] + = log(y_d[y->offset(output_delay_, label_0)]); + + for (int t = 1; t < T; ++t) { + // If there is not enough time to output the remaining labels or + // some labels have been skipped, then let log_alpha(u, t) continue to + // be kLogZero. + for (int u = std::max(0, U - (2 * (T - t))); + u < std::min(U, 2 * (t + 1)); + ++u) { + // Begin Graves Eq (7.9) + // Add in the u, t - 1 term. + double sum_log_alpha = kLogZero; + if (ctc_merge_repeated || l_prime_d[u] == blank_index_) { + sum_log_alpha = log_alpha_d[log_alpha->offset(u, t - 1)]; + } + + // Add in the u - 1, t - 1 term. + if (u > 0) { + sum_log_alpha + = LogSumExp(sum_log_alpha, + log_alpha_d[log_alpha->offset(u - 1, t - 1)]); + } + + // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2). + if (u > 1) { + const bool matching_labels_merge + = ctc_merge_repeated && (l_prime_d[u] == l_prime_d[u - 2]); + if (l_prime_d[u] != blank_index_ && !matching_labels_merge) { + sum_log_alpha + = LogSumExp(sum_log_alpha, + log_alpha_d[log_alpha->offset(u - 2, t - 1)]); + } + } + // Multiply the summed alphas with the activation log probability. + const Dtype y_v = y_d[y->offset(output_delay_ + t, l_prime_d[u])]; + log_alpha_d[log_alpha->offset(u, t)] = log(y_v) + sum_log_alpha; + } // End Graves Eq (7.9) + } +} + +template +void CTCLossLayer::CalculateBackwardVariables( + const SequenceBlob *l_prime, + const Blob* y, + bool ctc_merge_repeated, + CTCVariables *log_beta) const { + // Note that the order of log beta is N x T instead of T x N + const int U = l_prime->count(); + const int T = log_beta->channels(); + CHECK_EQ(U, log_beta->num()); + + // Data pointers, fill beta with kLogZero + double *log_beta_d = log_beta->mutable_cpu_data(); + caffe_set(log_beta->count(), kLogZero, log_beta_d); + const Dtype *y_d = y->cpu_data(); + + const int* l_prime_d = l_prime->cpu_data(); + + // Initial beta blaues in Graves Eq (7.13): log of probability 1. + for (int u = U - 2; u < U; ++u) { + log_beta_d[log_beta->offset(u, T - 1)] = 0; + } + + for (int t = T - 1 - 1; t >= 0; --t) { + // If ther is not enough time to output the remaining labels or + // some labels have been skipped, then let log_beta[u, t] continue to + // be kLogZero. + for (int u = std::max(0, U - (2 * (T - t))); + u < std::min(U, 2 * (t + 1)); + ++u) { + double &log_beta_ut = log_beta_d[log_beta->offset(u, t)]; + + // Begin Graves Eq (7.15) + // Add in the u, t + 1 term. + if (ctc_merge_repeated || l_prime_d[u] == blank_index_) { + const double &log_beta_ut1 = log_beta_d[log_beta->offset(u, t + 1)]; + const double &y_u0 + = y_d[y->offset(output_delay_ + t + 1, l_prime_d[u])]; + DCHECK_GE(y_u0, 0) + << "Output of the net must be a probability distribution."; + DCHECK_LE(y_u0, 1) + << "Output of the net must be a probability distribution."; + log_beta_ut = LogSumExp(log_beta_ut, log_beta_ut1 + log(y_u0)); + } + + // Add in the u + 1, t + 1 term. 
+ if (u + 1 < U) { + const double &log_beta_u1t1 + = log_beta_d[log_beta->offset(u + 1, t + 1)]; + const double &y_u1 + = y_d[y->offset(output_delay_ + t + 1, l_prime_d[u + 1])]; + DCHECK_GE(y_u1, 0) + << "Output of the net must be a probability distribution."; + DCHECK_LE(y_u1, 1) + << "Output of the net must be a probability distribution."; + log_beta_ut = LogSumExp(log_beta_ut, log_beta_u1t1 + log(y_u1)); + } + + // Add in the u + 2, t + 1 term if l_prime[u] != blank or l_prime[u+2] + if (u + 2 < U) { + const bool matching_labels_merge = + ctc_merge_repeated && (l_prime_d[u] == l_prime_d[u + 2]); + if (l_prime_d[u] != blank_index_ && !matching_labels_merge) { + const double &log_beta_u2t1 + = log_beta_d[log_beta->offset(u + 2, t + 1)]; + const double &y_u2 + = y_d[y->offset(output_delay_ + t + 1, l_prime_d[u + 2])]; + DCHECK_GE(y_u2, 0) + << "Output of the net must be a probability distribution."; + DCHECK_LE(y_u2, 1) + << "Output of the net must be a probability distribution."; + + // Add in u + 2 term. + log_beta_ut = LogSumExp(log_beta_ut, log_beta_u2t1 + log(y_u2)); + } + } + } // End Graves Eq. (7.15) + } +} + +template +void CTCLossLayer::CalculateGradient( + int b, + int seq_length, + const SequenceBlob *l_prime, + const Blob *y_d_blob, + const CTCVariables* log_alpha, + const CTCVariables* log_beta, + double log_p_z_x, + Blob *y) const { + const int L = C_; + const int T = seq_length; + CHECK_LE(seq_length, y->num()); + CHECK_EQ(L, y->height()); + const int U = l_prime->count(); + + const double* log_alpha_d = log_alpha->cpu_data(); + const double* log_beta_d = log_beta->cpu_data(); + const Dtype* y_d = y_d_blob->cpu_data(); + Dtype* y_diff_d = y->mutable_cpu_diff(); + const int* l_prime_d = l_prime->cpu_data(); + + DCHECK_EQ(y_diff_d[y->offset(0, b, 0)], static_cast(0)); + + // It is possible that no valid path is found if the activations for the + // targets are zero. 
+ if (log_p_z_x == kLogZero) { + LOG(WARNING) << "No valid path found."; + // dy is then y + for (int t = 0; t < T - output_delay_; ++t) { + for (int l = 0; l < L; ++l) { + y_diff_d[y->offset(output_delay_ + t, b, l)] + = y_d[y_d_blob->offset(output_delay_ + t, l)]; + } + } + return; + } + + + for (int t = 0; t < T - output_delay_; ++t) { + vector prob_sum(L, kLogZero); + + for (int u = 0; u < U; ++u) { + const int l = l_prime_d[u]; + prob_sum[l] + = LogSumExp(prob_sum[l], + log_alpha_d[log_alpha->offset(u, t)] + + log_beta_d[log_beta->offset(u, t)]); + } + + for (int l = 0; l < L; ++l) { + const double negative_term = exp(prob_sum[l] - log_p_z_x); + y_diff_d[y->offset(output_delay_ + t, b, l)] + = (y_d[y_d_blob->offset(output_delay_ + t, l)] - negative_term); + } + } +} + +INSTANTIATE_CLASS(CTCLossLayer); +REGISTER_LAYER_CLASS(CTCLoss); + +} // namespace caffe diff --git a/src/caffe/layers/reverse_layer.cpp b/src/caffe/layers/reverse_layer.cpp new file mode 100644 index 00000000000..fc0121c7b98 --- /dev/null +++ b/src/caffe/layers/reverse_layer.cpp @@ -0,0 +1,81 @@ +#include "caffe/layers/reverse_layer.hpp" + +#include + +namespace caffe { + +template +ReverseLayer::ReverseLayer(const LayerParameter& param) + : NeuronLayer(param) + , axis_(param.reverse_param().axis()) { + CHECK_GE(axis_, 0); +} + +template +void ReverseLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::LayerSetUp(bottom, top); + CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not " + "allow in-place computation."; + + CHECK_LT(axis_, bottom[0]->num_axes()) + << "Axis must be less than the number of axis for reversing"; +} + +template +void ReverseLayer::Forward_cpu( + const vector*>& bottom, const vector*>& top) { + const Dtype* src = bottom[0]->cpu_data(); + + const int count = top[0]->count(); + const int axis_count = top[0]->count(axis_); + const int copy_amount + = (axis_ + 1 == top[0]->num_axes()) ? 1 : top[0]->count(axis_ + 1); + const int num_fix = (axis_ > 0) ? count / axis_count : 1; + const int sub_iter_max = top[0]->shape(axis_); + + for (int fix = 0; fix < num_fix; ++fix) { + Dtype* target = top[0]->mutable_cpu_data() + + (fix + 1) * copy_amount * sub_iter_max - copy_amount; + for (int i = 0; i < sub_iter_max; ++i) { + caffe_copy(copy_amount, src, target); + src += copy_amount; // normal order + target -= copy_amount; + } + } +} + +template +void ReverseLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { return; } + + Dtype* target = bottom[0]->mutable_cpu_diff(); + + const int count = top[0]->count(); + const int axis_count = top[0]->count(axis_); + const int copy_amount = + (axis_ + 1 == top[0]->num_axes()) ? 1 : top[0]->count(axis_ + 1); + const int num_fix = (axis_ > 0) ? 
count / axis_count : 1; + const int sub_iter_max = top[0]->shape(axis_); + + for (int fix = 0; fix < num_fix; ++fix) { + const Dtype* src + = top[0]->cpu_diff() + (fix + 1) * copy_amount * sub_iter_max + - copy_amount; + for (int i = 0; i < sub_iter_max; ++i) { + caffe_copy(copy_amount, src, target); + target += copy_amount; // normal order + src -= copy_amount; // reverse order + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(ReverseLayer); +#endif + +INSTANTIATE_CLASS(ReverseLayer); +REGISTER_LAYER_CLASS(Reverse); + +} // namespace caffe diff --git a/src/caffe/layers/reverse_layer.cu b/src/caffe/layers/reverse_layer.cu new file mode 100644 index 00000000000..9b12cfbcf84 --- /dev/null +++ b/src/caffe/layers/reverse_layer.cu @@ -0,0 +1,59 @@ +#include + +#include "caffe/layers/reverse_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void ReverseLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* src = bottom[0]->gpu_data(); + + const int count = top[0]->count(); + const int axis_count = top[0]->count(axis_); + const int copy_amount + = (axis_ + 1 == top[0]->num_axes()) ? 1 : top[0]->count(axis_ + 1); + const int num_fix = (axis_ > 0) ? count / axis_count : 1; + const int sub_iter_max = top[0]->shape(axis_); + + for (int fix = 0; fix < num_fix; ++fix) { + Dtype* target = top[0]->mutable_gpu_data() + + (fix + 1) * copy_amount * sub_iter_max - copy_amount; + for (int i = 0; i < sub_iter_max; ++i) { + caffe_copy(copy_amount, src, target); + src += copy_amount; // normal order + target -= copy_amount; // reverse order + } + } +} + +template +void ReverseLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { return; } + + Dtype* target = bottom[0]->mutable_gpu_diff(); + + const int count = top[0]->count(); + const int axis_count = top[0]->count(axis_); + const int copy_amount + = (axis_ + 1 == top[0]->num_axes()) ? 1 : top[0]->count(axis_ + 1); + const int num_fix = (axis_ > 0) ? count / axis_count : 1; + const int sub_iter_max = top[0]->shape(axis_); + + for (int fix = 0; fix < num_fix; ++fix) { + const Dtype* src = top[0]->gpu_diff() + + (fix + 1) * copy_amount * sub_iter_max - copy_amount; + for (int i = 0; i < sub_iter_max; ++i) { + caffe_copy(copy_amount, src, target); + target += copy_amount; // normal order + src -= copy_amount; // reverse order + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(ReverseLayer); + + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 6940a705eb6..f2d3aa5f93b 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -306,7 +306,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. 
// -// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param) +// LayerParameter next available layer-specific ID: 150 (last added: ctc_decoder_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -366,6 +366,8 @@ message LayerParameter { optional ContrastiveLossParameter contrastive_loss_param = 105; optional ConvolutionParameter convolution_param = 106; optional CropParameter crop_param = 144; + optional CTCDecoderParameter ctc_decoder_param = 149; + optional CTCLossParameter ctc_loss_param = 148; optional DataParameter data_param = 107; optional DropoutParameter dropout_param = 108; optional DummyDataParameter dummy_data_param = 109; @@ -394,6 +396,7 @@ message LayerParameter { optional ReductionParameter reduction_param = 136; optional ReLUParameter relu_param = 123; optional ReshapeParameter reshape_param = 133; + optional ReverseParameter reverse_param = 147; optional ScaleParameter scale_param = 142; optional SigmoidParameter sigmoid_param = 124; optional SoftmaxParameter softmax_param = 125; @@ -624,6 +627,47 @@ message CropParameter { repeated uint32 offset = 2; } +message CTCDecoderParameter { + // The index of the blank index in the labels. A negative (default) + // value will use the last index + optional int32 blank_index = 1 [default = -1]; + + // Collapse the repeated labels during the ctc calculation + // e.g. collapse [0bbb11bb11bb0b2] to [01102] instead of [0111102], + // where b means blank label. + // The default behaviour is to merge repeated labels. + // Note: blank labels will be removed in any case. + optional bool ctc_merge_repeated = 2 [default = true]; +} + +message CTCLossParameter { + // Adds delayed output to the CTC loss calculation (untested!) + optional int32 output_delay = 1 [default = 0]; + + // The index of the blank index in the labels. A negative (default) + // value will use the last index + optional int32 blank_index = 2 [default = -1]; + + // Collapse repeating labels of the target sequence before calculating + // the loss and the gradients (e.g. collapse [01102] to [0102]) + // The default behaviour is to keep repeated labels. Elsewise the + // network will not learn to predict repetitions. + optional bool preprocess_collapse_repeated = 3 [default = false]; + + // Collapse the repeated labels during the ctc calculation + // e.g collapse [0bbb11bb11bb0b2] to [01102] instead of [0111102], + // where b means blank label. + // The default behaviour is to merge repeated labels. + // Note: blank labels will be removed in any case. + optional bool ctc_merge_repeated = 4 [default = true]; + + /// This parameter is for test cases only! + /// The time for which to calculate the loss (see Graves Eq. (7.27) ) + /// Note that the result must be the same for each 0 <= t < T + /// Therefore you can chose an arbitrary value, default 0 + optional int32 loss_calculation_t = 5 [default = 0]; +} + message DataParameter { enum DB { LEVELDB = 0; @@ -1057,6 +1101,21 @@ message ReshapeParameter { optional int32 num_axes = 3 [default = -1]; } +message ReverseParameter { + // axis controls the data axis which shall be inverted. 
+ // The layout of the content will not be inverted + // + // The default axis is 0 that means: + // data_previous[n] == data_afterwards[N - n -1] + // where N is the shape of axis(n) + // + // Usually this layer will be used with recurrent layers to invert the + // time axis which is axis 0 + // This layer will therefore swap the order in time but not the + // order of the actual data. + optional int32 axis = 1 [default = 0]; +} + message ScaleParameter { // The first axis of bottom[0] (the first input Blob) along which to apply // bottom[1] (the second input Blob). May be negative to index from the end diff --git a/src/caffe/test/test_ctc_decoder_layer.cpp b/src/caffe/test/test_ctc_decoder_layer.cpp new file mode 100644 index 00000000000..61384a275f7 --- /dev/null +++ b/src/caffe/test/test_ctc_decoder_layer.cpp @@ -0,0 +1,220 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/ctc_decoder_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class CTCDecoderLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + CTCDecoderLayerTest() + : T_(4), + N_(2), + num_labels_(3), + blob_bottom_data_(new Blob(T_, N_, num_labels_, 1)), + blob_bottom_seq_ind_(new Blob(T_, N_, 1, 1)), + blob_bottom_target_seq_(new Blob(T_, N_, 1, 1)), + blob_top_scores_(new Blob()), + blob_top_accuracy_(new Blob()) { + // Add blobs to the correct bottom/top lists + blob_bottom_vec_.push_back(blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_seq_ind_); + blob_top_vec_.push_back(blob_top_scores_); + } + + virtual ~CTCDecoderLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_seq_ind_; + delete blob_top_scores_; + } + + void Reshape(int t, int n, int c) { + T_ = t; + N_ = n; + num_labels_ = c; + blob_bottom_data_->Reshape(T_, N_, num_labels_, 1); + blob_bottom_seq_ind_->Reshape(T_, N_, 1, 1); + blob_bottom_target_seq_->Reshape(T_, N_, 1, 1); + } + + void AddAccuracyOutput() { + blob_bottom_vec_.push_back(blob_bottom_target_seq_); + blob_top_vec_.push_back(blob_top_accuracy_); + } + + template + vector raw(const T data[], int size) { + vector o(size); + caffe_copy(size, data, o.data()); + return o; + } + + vector log(const vector &in) { + vector o(in.size()); + for (size_t i = 0; i < in.size(); ++i) { + o[i] = std::log(in[i]); + } + return o; + } + + Dtype sum(const vector &in) { + return std::accumulate(in.begin(), in.end(), static_cast(0)); + } + + vector neg(const vector &in) { + vector o(in.size()); + for (size_t i = 0; i < in.size(); ++i) { + o[i] = -in[i]; + } + return o; + } + + void TestGreedyDecoder(bool check_accuracy = true) { + if (check_accuracy) { + AddAccuracyOutput(); + } + + // Test two batch entries - best path decoder. 
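+    // With the default blank index (depth - 1 = 3) the first entry should
+    // decode to [0 1] and the second one to [1 1 0], matching
+    // correct_sequence_0 and correct_sequence_1 below.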
+ const int max_time_steps = 6; + const int depth = 4; + + // const int seq_len_0 = 4; + const Dtype input_prob_matrix_0[max_time_steps * depth] = + {1.0, 0.0, 0.0, 0.0, // t=0 + 0.0, 0.0, 0.4, 0.6, // t=1 + 0.0, 0.0, 0.4, 0.6, // t=2 + 0.0, 0.9, 0.1, 0.0, // t=3 + 0.0, 0.0, 0.0, 0.0, // t=4 (ignored) + 0.0, 0.0, 0.0, 0.0 // t=5 (ignored) + }; + + const vector input_log_prob_matrix_0( + log(raw(input_prob_matrix_0, max_time_steps * depth))); + const Dtype prob_truth_0[depth] = {1.0, 0.6, 0.6, 0.9}; + const int label_len_0 = 2; + const int correct_sequence_0[label_len_0] = {0, 1}; + + // const int seq_len_1 = 5; + const Dtype input_prob_matrix_1[max_time_steps * depth] = + {0.1, 0.9, 0.0, 0.0, // t=0 + 0.0, 0.9, 0.1, 0.0, // t=1 + 0.0, 0.0, 0.1, 0.9, // t=2 + 0.0, 0.9, 0.1, 0.1, // t=3 + 0.9, 0.1, 0.0, 0.0, // t=4 + 0.0, 0.0, 0.0, 0.0 // t=5 (ignored) + }; + + const vector input_log_prob_matrix_1( + log(raw(input_prob_matrix_1, max_time_steps * depth))); + const Dtype prob_truth_1[depth] = {0.9, 0.9, 0.9, 0.9}; + const int label_len_1 = 3; + const int correct_sequence_1[label_len_1] = {1, 1, 0}; + + const Dtype log_prob_truth[2] = { + sum(neg(log(raw(prob_truth_0, depth)))), + sum(neg(log(raw(prob_truth_1, depth)))) + }; + + Reshape(max_time_steps, 2, depth); + + // copy data + Dtype* data = blob_bottom_data_->mutable_cpu_data(); + for (int t = 0; t < max_time_steps; ++t) { + for (int c = 0; c < depth; ++c) { + data[blob_bottom_data_->offset(t, 0, c)] + = input_log_prob_matrix_0[t * depth + c]; + data[blob_bottom_data_->offset(t, 1, c)] + = input_log_prob_matrix_1[t * depth + c]; + } + } + + // set sequence indicators + Dtype* seq_ind = blob_bottom_seq_ind_->mutable_cpu_data(); + caffe_set(blob_bottom_seq_ind_->count(), static_cast(1), seq_ind); + // sequence 1: + seq_ind[0 * 2 + 0] = seq_ind[4 * 2 + 0] = seq_ind[5 * 2 + 0] = 0; + // sequence 2; + seq_ind[0 * 2 + 1] = seq_ind[5 * 2 + 1] = 0; + + // target sequences + if (check_accuracy) { + Dtype* target_seq = blob_bottom_target_seq_->mutable_cpu_data(); + caffe_set(blob_bottom_target_seq_->count(), + static_cast(-1), target_seq); + for (int i = 0; i < label_len_0; ++i) { + target_seq[blob_bottom_target_seq_->offset(i, 0)] + = correct_sequence_0[i]; + } + + for (int i = 0; i < label_len_1; ++i) { + target_seq[blob_bottom_target_seq_->offset(i, 1)] + = correct_sequence_1[i]; + } + } + + LayerParameter layer_param; + CTCGreedyDecoderLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + + layer.Forward(blob_bottom_vec_, blob_top_vec_); + + CHECK_EQ(layer.OutputSequences().size(), 2); + + const Dtype* scores = blob_top_scores_->cpu_data(); + + // check n=0 + EXPECT_EQ(label_len_0, layer.OutputSequences()[0].size()); + for (int i = 0; i < label_len_0; ++i) { + EXPECT_EQ(correct_sequence_0[i], layer.OutputSequences()[0][i]); + } + EXPECT_FLOAT_EQ(scores[0], log_prob_truth[0]); + + // check n=1 + EXPECT_EQ(label_len_1, layer.OutputSequences()[1].size()); + for (int i = 0; i < label_len_1; ++i) { + EXPECT_EQ(correct_sequence_1[i], layer.OutputSequences()[1][i]); + } + EXPECT_FLOAT_EQ(scores[0], log_prob_truth[0]); + + if (check_accuracy) { + const Dtype *acc = blob_top_accuracy_->cpu_data(); + // output must have a edit distance of 0, + // the accuracy must therefore be 100% + EXPECT_FLOAT_EQ(acc[0], 1); + } + } + + int T_; + int N_; + int num_labels_; + Blob* const blob_bottom_data_; + Blob* const blob_bottom_seq_ind_; + Blob* const blob_bottom_target_seq_; + Blob* const blob_top_scores_; + Blob* const 
blob_top_accuracy_; + + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(CTCDecoderLayerTest, TestDtypesAndDevices); + +TYPED_TEST(CTCDecoderLayerTest, TestGreedyDecoder) { + this->TestGreedyDecoder(true); // with acc test + this->TestGreedyDecoder(false); // without acc test +} + + +} // namespace caffe diff --git a/src/caffe/test/test_ctc_loss_layer.cpp b/src/caffe/test/test_ctc_loss_layer.cpp new file mode 100644 index 00000000000..051e83cbcd5 --- /dev/null +++ b/src/caffe/test/test_ctc_loss_layer.cpp @@ -0,0 +1,634 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/ctc_loss_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class CTCLossLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + CTCLossLayerTest() + : T_(4), + N_(2), + num_labels_(3), + blob_bottom_data_(new Blob(T_, N_, num_labels_, 1)), + blob_bottom_label_(new Blob(T_, N_, 1, 1)), + blob_bottom_seq_ind_(new Blob(T_, N_, 1, 1)), + blob_top_loss_(new Blob()) { + // Add blobs to the correct bottom/top lists + blob_bottom_vec_.push_back(blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_seq_ind_); + blob_bottom_vec_.push_back(blob_bottom_label_); + blob_top_vec_.push_back(blob_top_loss_); + } + + virtual ~CTCLossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_bottom_seq_ind_; + delete blob_top_loss_; + } + + void Reshape(int t, int n, int c) { + T_ = t; + N_ = n; + num_labels_ = c; + blob_bottom_data_->Reshape(T_, N_, num_labels_, 1); + blob_bottom_label_->Reshape(T_, N_, 1, 1); + blob_bottom_seq_ind_->Reshape(T_, N_, 1, 1); + } + + void InitConstantCorrect() { + // fill the values with constant values + // The prediction is 100% correct so that the loss + // must be 0 + + FillerParameter filler_c1_param; + filler_c1_param.set_value(1); + ConstantFiller c1_filler(filler_c1_param); + c1_filler.Fill(blob_bottom_seq_ind_); + + FillerParameter filler_c0_param; + filler_c0_param.set_value(0); + ConstantFiller c0_filler(filler_c0_param); + c0_filler.Fill(blob_bottom_data_); + + FillerParameter filler_cn1_param; + filler_cn1_param.set_value(-1); + ConstantFiller cn1_filler(filler_cn1_param); + cn1_filler.Fill(blob_bottom_label_); + + // sequence start (full size) + for (int b = 0; b < N_; ++b) { + blob_bottom_seq_ind_->mutable_cpu_data()[b] = 0; + } + + const Dtype one = std::numeric_limits::max(); + + // set label + Dtype *label = blob_bottom_label_->mutable_cpu_data(); + label[blob_bottom_label_->offset(0, 0)] = 0; + label[blob_bottom_label_->offset(1, 0)] = 1; + + label[blob_bottom_label_->offset(0, 1)] = 0; + label[blob_bottom_label_->offset(1, 1)] = 1; + + // set probabilities + // (100% correct, but second with one additional timestep) + Dtype *data = blob_bottom_data_->mutable_cpu_data(); + data[blob_bottom_data_->offset(0, 0, 0)] = one; + data[blob_bottom_data_->offset(1, 0, 1)] = one; + data[blob_bottom_data_->offset(2, 0, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(3, 0, num_labels_ - 1)] = one; + + data[blob_bottom_data_->offset(0, 1, 0)] = one; + data[blob_bottom_data_->offset(1, 1, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(2, 1, 1)] = one; + data[blob_bottom_data_->offset(3, 1, num_labels_ - 1)] = one; + + + // Check data for consistency + CheckData(); + } + + void 
InitConstantWrong() { + // fill the values with constant values + // The prediction is wrong in one label + + FillerParameter filler_c1_param; + filler_c1_param.set_value(1); + ConstantFiller c1_filler(filler_c1_param); + c1_filler.Fill(blob_bottom_seq_ind_); + + FillerParameter filler_c0_param; + filler_c0_param.set_value(0); + ConstantFiller c0_filler(filler_c0_param); + c0_filler.Fill(blob_bottom_data_); + + FillerParameter filler_cn1_param; + filler_cn1_param.set_value(-1); + ConstantFiller cn1_filler(filler_cn1_param); + cn1_filler.Fill(blob_bottom_label_); + + // sequence start (full size) + for (int b = 0; b < N_; ++b) { + blob_bottom_seq_ind_->mutable_cpu_data()[b] = 0; + } + + const Dtype one = std::numeric_limits::max(); + + // set label + Dtype *label = blob_bottom_label_->mutable_cpu_data(); + label[blob_bottom_label_->offset(0, 0)] = 0; + label[blob_bottom_label_->offset(1, 0)] = 1; + + label[blob_bottom_label_->offset(0, 1)] = 0; + label[blob_bottom_label_->offset(1, 1)] = 1; + + // set probabilities + // (flipped predictions compared to correct, last row 1<->0) + Dtype *data = blob_bottom_data_->mutable_cpu_data(); + data[blob_bottom_data_->offset(0, 0, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(1, 0, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(2, 0, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(3, 0, num_labels_ - 1)] = one; + + data[blob_bottom_data_->offset(0, 1, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(1, 1, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(2, 1, num_labels_ - 1)] = one; + data[blob_bottom_data_->offset(3, 1, num_labels_ - 1)] = one; + + + // Check data for consistency + CheckData(); + } + + void InitConstantEqual(int sequence_length, int T = 5, int C = 10) { + CHECK_LE(sequence_length, T); + // fill the values with constant values + // The prediction is wrong is equal in each position (everything is 0) + + Reshape(T, 1, C); + + FillerParameter filler_c1_param; + filler_c1_param.set_value(1); + ConstantFiller c1_filler(filler_c1_param); + c1_filler.Fill(blob_bottom_seq_ind_); + + FillerParameter filler_c0_param; + filler_c0_param.set_value(0); + ConstantFiller c0_filler(filler_c0_param); + c0_filler.Fill(blob_bottom_data_); + + FillerParameter filler_cn1_param; + filler_cn1_param.set_value(-1); + ConstantFiller cn1_filler(filler_cn1_param); + cn1_filler.Fill(blob_bottom_label_); + + // sequence start (full size) + for (int b = 0; b < N_; ++b) { + blob_bottom_seq_ind_->mutable_cpu_data()[b] = 0; + } + + const Dtype one = std::numeric_limits::max(); + + // set label + Dtype *label = blob_bottom_label_->mutable_cpu_data(); + int label_to_set = 0; + for (int t = 0; t < sequence_length; ++t) { + label[blob_bottom_label_->offset(t, 0)] = label_to_set; + label_to_set = (label_to_set + 1) % (num_labels_ - 1); + } + + // Check data for consistency + CheckData(); + } + + void InitRandom() { + // This will create random data and random labels for all blobs. + + // increase T_, N_, num_labels_ + Reshape(41, 29, 37); + + // Fill the data + // ============================================================= + + FillerParameter gfp; + gfp.set_std(1); + GaussianFiller gf(gfp); + // random data + gf.Fill(blob_bottom_data_); + + // Fill the sequence indicators + // ============================================================== + + // arbitrary sequence length, at least 1 + vector seq_lengths(N_); + caffe_rng_uniform(N_, 1.0, T_ - 0.001, seq_lengths.data()); + + // 1. 
Fill with 1 + FillerParameter filler_c1_param; + filler_c1_param.set_value(1); + ConstantFiller c1_filler(filler_c1_param); + c1_filler.Fill(blob_bottom_seq_ind_); + + // 2. sequence start (always at beginning = 0) + // 3. sequence length (all further = 0) + Dtype* seq_data = blob_bottom_seq_ind_->mutable_cpu_data(); + for (int b = 0; b < N_; ++b) { + seq_data[b] = 0; + for (int t = static_cast(seq_lengths[b]); t < T_; ++t) { + seq_data[blob_bottom_seq_ind_->offset(t, b)] = 0; + } + } + + // Fill the labels + // ============================================================== + + // arbitrary labels + FillerParameter lufp; + lufp.set_min(0); + lufp.set_max(num_labels_ - 1.0001); // note that last label is blank + UniformFiller luf(lufp); + luf.Fill(blob_bottom_label_); + + // loop through all elements and set to integer values (e.g. 4.25 to 4) + Dtype* label = blob_bottom_label_->mutable_cpu_data(); + const Dtype* end_label = label + blob_bottom_label_->count(); + for (; label < end_label; ++label) { + *label = static_cast(static_cast(*label)); + } + + // loop though all elements and set the label length to + // 1 <= label length <= sequence length + label = blob_bottom_label_->mutable_cpu_data(); + for (int n = 0; n < N_; ++n) { + const int seq_len = static_cast(seq_lengths[n]); + const int label_len = caffe_rng_rand() % (seq_len) + 1; + CHECK_LE(label_len, seq_len); + CHECK_GE(label_len, 0); + + for (int t = label_len; t < T_; ++t) { + label[blob_bottom_label_->offset(t, n)] = -1; + } + } + + // Check data for consistency + // ============================================================== + CheckData(); + } + + void CheckData() { + // check label_length <= sequence length + const Dtype* label_data = blob_bottom_label_->cpu_data(); + const Dtype* seq_data = blob_bottom_seq_ind_->cpu_data(); + for (int n = 0; n < N_; ++n) { + Dtype seq_len = -1; + Dtype lab_len = -1; + for (int t = 0; t < T_; ++t) { + const Dtype lab = label_data[blob_bottom_label_->offset(t, n)]; + + // expect all following labels to be negative (not filled) + if (lab_len >= 0.0) { + EXPECT_LT(lab, 0.0); + } else { + // if first not filled label appears we know the label length + if (lab < 0.0) { + lab_len = t; + } + } + } + + for (int t = 1; t < T_; ++t) { + const Dtype seq + = seq_data[blob_bottom_seq_ind_->offset(t, n)]; + + // expect all following sequence indicators to be 0.0, in our test case + if (seq_len >= 0.0) { + EXPECT_DOUBLE_EQ(seq, 0.0); + } else { + // if another 0 appears we know the sequence length + if (seq == 0.0) { + seq_len = t; + } + } + } + + // check if no end indicator was found, therefore the complete T_ is used + if (lab_len < 0.0) { + lab_len = T_; + } + if (seq_len < 0.0) { + seq_len = T_; + } + + EXPECT_GE(seq_len, 0); + EXPECT_GE(lab_len, 0); + EXPECT_LE(lab_len, seq_len); + } + } + + void TestForward() { + InitConstantCorrect(); // constant and correct data (loss must be 0) + + LayerParameter layer_param; + CTCLossLayer layer_1(layer_param); + layer_1.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype loss_weight_1 = + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss = blob_top_loss_->cpu_data()[0]; + EXPECT_FLOAT_EQ(loss, 0.0); + + EXPECT_EQ(layer_1.SequenceLength().num(), N_); + EXPECT_EQ(layer_1.SequenceLength().count(), N_); + EXPECT_EQ(layer_1.SequenceLength().cpu_data()[0], 4); + EXPECT_EQ(layer_1.SequenceLength().cpu_data()[1], 4); + + EXPECT_EQ(layer_1.LabelLength().num(), N_); + EXPECT_EQ(layer_1.LabelLength().count(), N_); + 
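+ // InitConstantCorrect sets the two-element target {0, 1} for both batch entries, so both label lengths are 2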
EXPECT_EQ(layer_1.LabelLength().cpu_data()[0], 2); + EXPECT_EQ(layer_1.LabelLength().cpu_data()[1], 2); + + // check loss for all other t + // (to check Graves Eq. (7.27) that holds for all t) + for (int t = 1; t < T_; ++t) { + layer_1.SetLossCalculationT(t); + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss_t = blob_top_loss_->cpu_data()[0]; + EXPECT_FLOAT_EQ(loss, loss_t); + } + + // The gradients (deltas on all variables) must be 0 in this special case + EXPECT_FLOAT_EQ(blob_bottom_data_->asum_diff(), 0); + } + + void TestForwardWrong() { + InitConstantWrong(); + LayerParameter layer_param; + + CTCLossLayer layer_1(layer_param); + layer_1.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype loss_weight_1 = + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss = blob_top_loss_->cpu_data()[0]; + EXPECT_FLOAT_EQ(loss, std::numeric_limits::max()); + + EXPECT_EQ(layer_1.SequenceLength().num(), N_); + EXPECT_EQ(layer_1.SequenceLength().count(), N_); + EXPECT_EQ(layer_1.SequenceLength().cpu_data()[0], 4); + EXPECT_EQ(layer_1.SequenceLength().cpu_data()[1], 4); + + EXPECT_EQ(layer_1.LabelLength().num(), N_); + EXPECT_EQ(layer_1.LabelLength().count(), N_); + EXPECT_EQ(layer_1.LabelLength().cpu_data()[0], 2); + EXPECT_EQ(layer_1.LabelLength().cpu_data()[1], 2); + + // check loss for all other t + // (to check Graves Eq. (7.27) that holds for all t) + for (int t = 1; t < T_; ++t) { + layer_1.SetLossCalculationT(t); + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss_t = blob_top_loss_->cpu_data()[0]; + EXPECT_FLOAT_EQ(loss, loss_t); + } + + + vector prob_down(0, false); + prob_down[0] = true; + layer_1.Backward(this->blob_top_vec_, prob_down, this->blob_bottom_vec_); + + // expected output of gradient is softmax of input + for (int i = 0; i < blob_bottom_data_->count(); ++i) { + const Dtype d = blob_bottom_data_->cpu_data()[i]; + const Dtype dd = blob_bottom_data_->cpu_diff()[i]; + if (d == std::numeric_limits::max()) { + EXPECT_FLOAT_EQ(dd, 1); + } else { + EXPECT_FLOAT_EQ(dd, 0); + } + } + } + + // Returns value of Binomial Coefficient C(n, k) + int binomialCoeff(int n, int k) { + int res = 1; + + // Since C(n, k) = C(n, n-k) + if ( k > n - k ) + k = n - k; + + // Calculate value of [n * (n-1) *---* (n-k+1)] / [k * (k-1) *----* 1] + for (int i = 0; i < k; ++i) { + res *= (n - i); + res /= (i + 1); + } + + return res; + } + + void TestForwardEqual() { + int sequence_length = 4; + InitConstantEqual(sequence_length, 20, 5); + LayerParameter layer_param; + + CTCLossLayer layer_1(layer_param); + layer_1.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype loss_weight_1 = + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss = blob_top_loss_->cpu_data()[0]; + + // the expected loss + Dtype expected_loss = -log(binomialCoeff(T_ + sequence_length, + 2 * sequence_length) + / pow(num_labels_, T_)); + EXPECT_FLOAT_EQ(loss, expected_loss); + + // check loss for all other t + // (to check Graves Eq. 
(7.27) that holds for all t) + for (int t = 1; t < T_; ++t) { + layer_1.SetLossCalculationT(t); + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss_t = blob_top_loss_->cpu_data()[0]; + EXPECT_FLOAT_EQ(loss, loss_t); + } + } + + void TestForwardRandom() { + InitRandom(); + + LayerParameter layer_param; + CTCLossLayer layer_1(layer_param); + layer_1.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype loss_weight_1 = + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss = blob_top_loss_->cpu_data()[0]; + + // check loss for all other t + // (to check Graves Eq. (7.27) that holds for all t) + for (int t = 1; t < T_; ++t) { + layer_1.SetLossCalculationT(t); + layer_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss_t = blob_top_loss_->cpu_data()[0]; + + EXPECT_DOUBLE_EQ(loss, loss_t) << " at loss calculation for t = " << t; + } + } + + void TestGradient() { + // Input and ground truth for gradient from Alex Graves' implementation + // (taken from Tensorflow test) + Reshape(5, 2, 6); + + const Dtype targets_0[5] = {0, 1, 2, 1, 0}; + const Dtype loss_log_prob_0 = -3.34211; + + const Dtype input_prob_matrix_0[5 * 6] = + {0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, + 0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436, + 0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688, + 0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533, + 0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107}; + vector input_log_prob_matrix_0(5 * 6); + + for (int i = 0; i < input_log_prob_matrix_0.size(); ++i) { + input_log_prob_matrix_0[i] = log(input_prob_matrix_0[i]); + } + + const Dtype gradient_log_prob_0[5 * 6] = + {-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, + 0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436, + 0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688, + 0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533, + -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107}; + + const Dtype targets_1[4] = {0, 1, 1, 0}; + const Dtype loss_log_prob_1 = -5.42262; + + const Dtype input_prob_matrix_1[5 * 6] = + {0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, + 0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549, + 0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456, + 0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345, + 0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046}; + + vector input_log_prob_matrix_1(5 * 6); + + for (int i = 0; i < input_log_prob_matrix_1.size(); ++i) { + input_log_prob_matrix_1[i] = log(input_prob_matrix_1[i]); + } + + const Dtype gradient_log_prob_1[5 * 6] = + {-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, + 0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549, + 0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544, + 0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345, + -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046}; + + Dtype *data = blob_bottom_data_->mutable_cpu_data(); + for (int t = 0; t < T_; ++t) { + for (int c = 0; c < num_labels_; ++c) { + data[blob_bottom_data_->offset(t, 0, c)] + = input_log_prob_matrix_0[t * num_labels_ + c]; + data[blob_bottom_data_->offset(t, 1, c)] + = input_log_prob_matrix_1[t * num_labels_ + c]; + } + } + + FillerParameter filler_c1_param; + filler_c1_param.set_value(1); + ConstantFiller 
c1_filler(filler_c1_param); + c1_filler.Fill(blob_bottom_seq_ind_); + for (int n = 0; n < N_; ++n) { + blob_bottom_seq_ind_->mutable_cpu_data()[n] = 0; + } + + FillerParameter filler_cn1_param; + filler_cn1_param.set_value(-1); + ConstantFiller cn1_filler(filler_cn1_param); + cn1_filler.Fill(blob_bottom_label_); + Dtype* label_data = blob_bottom_label_->mutable_cpu_data(); + for (int t = 0; t < 5; ++t) { + label_data[blob_bottom_label_->offset(t, 0)] = targets_0[t]; + } + + for (int t = 0; t < 4; ++t) { + label_data[blob_bottom_label_->offset(t, 1)] = targets_1[t]; + } + + LayerParameter layer_param; + CTCLossLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + + // forward AND backward pass + const Dtype loss_weight_1 = + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + + const Dtype loss = blob_top_loss_->cpu_data()[0]; + EXPECT_LE(abs((-loss) - (loss_log_prob_0 + loss_log_prob_1) / 2), + 0.000001); + + // check loss for all other t + // (to check Graves Eq. (7.27) that holds for all t) + for (int t = 1; t < T_; ++t) { + layer.SetLossCalculationT(t); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + const Dtype loss_t = blob_top_loss_->cpu_data()[0]; + + EXPECT_FLOAT_EQ(loss, loss_t) + << " at loss calculation for t = " << t; + } + + vector prob_down(0, false); + prob_down[0] = true; + layer.Backward(this->blob_top_vec_, prob_down, this->blob_bottom_vec_); + + const Dtype *diff = blob_bottom_data_->cpu_diff(); + for (int t = 0; t < T_; ++t) { + for (int c = 0; c < num_labels_; ++c) { + EXPECT_LE(std::abs(diff[blob_bottom_data_->offset(t, 0, c)] + - gradient_log_prob_0[t * num_labels_ + c]), + 0.000001); + EXPECT_LE(std::abs(diff[blob_bottom_data_->offset(t, 1, c)] + - gradient_log_prob_1[t * num_labels_ + c]), + 0.000001); + } + } + } + + int T_; + int N_; + int num_labels_; + Blob* const blob_bottom_data_; + Blob* const blob_bottom_label_; + Blob* const blob_bottom_seq_ind_; + Blob* const blob_top_loss_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(CTCLossLayerTest, TestDtypesAndDevices); + +TYPED_TEST(CTCLossLayerTest, TestForward) { + this->TestForward(); +} + +TYPED_TEST(CTCLossLayerTest, TestForwardWrong) { + this->TestForwardWrong(); +} + +TYPED_TEST(CTCLossLayerTest, TestForwardEqual) { + this->TestForwardEqual(); +} + +TYPED_TEST(CTCLossLayerTest, TestForwardRandom) { + this->TestForwardRandom(); +} + +TYPED_TEST(CTCLossLayerTest, TestGradient) { + this->TestGradient(); +} + + +} // namespace caffe diff --git a/src/caffe/test/test_reverse_layer.cpp b/src/caffe/test/test_reverse_layer.cpp new file mode 100644 index 00000000000..8101271044d --- /dev/null +++ b/src/caffe/test/test_reverse_layer.cpp @@ -0,0 +1,366 @@ +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/reverse_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +#ifndef CPU_ONLY +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; +#endif + +template +class ReverseLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + ReverseLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 1)), + blob_top_(new Blob()) { + // fill the values + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_top_vec_.push_back(blob_top_); + blob_bottom_vec_.push_back(blob_bottom_); + } + + + 
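+ // note: TestForwardAxis/TestBackwardAxis reshape these blobs to 5 x 2 x 1 x 3 before copying in the test data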
virtual ~ReverseLayerTest() { + delete blob_bottom_; + delete blob_top_; + } + + void TestForwardAxis( + int axis, + int s0, + int s1, + int s2, + int s3, + const Dtype data_in[], + const Dtype data_expected[]) { + this->blob_bottom_vec_.push_back(this->blob_bottom_); + LayerParameter layer_param; + ReverseParameter* reverse_param = + layer_param.mutable_reverse_param(); + + reverse_param->set_axis(axis); + + shared_ptr > layer( + new ReverseLayer(layer_param)); + + // create dummy data and diff + blob_bottom_->Reshape(5, 2, 1, 3); + blob_top_->ReshapeLike(*blob_bottom_); + + // copy input data + caffe_copy(blob_bottom_->count(), data_in, + blob_bottom_->mutable_cpu_data()); + + // Forward data + layer->Forward(blob_bottom_vec_, blob_top_vec_); + + // Output of top must match the expected data + EXPECT_EQ(blob_bottom_->count(), blob_top_->count()); + + for (int i = 0; i < blob_top_->count(); ++i) { + EXPECT_FLOAT_EQ(data_expected[i], blob_top_->cpu_data()[i]); + } + } + + void TestBackwardAxis( + int axis, + int s0, + int s1, + int s2, + int s3, + const Dtype diff_in[], + const Dtype diff_expected[]) { + this->blob_bottom_vec_.push_back(this->blob_bottom_); + LayerParameter layer_param; + ReverseParameter* reverse_param = + layer_param.mutable_reverse_param(); + reverse_param->set_axis(axis); + + shared_ptr > layer( + new ReverseLayer(layer_param)); + + // create dummy data and diff + blob_bottom_->Reshape(5, 2, 1, 3); + blob_top_->ReshapeLike(*blob_bottom_); + + // copy input diff + caffe_copy(blob_top_->count(), diff_in, blob_top_->mutable_cpu_diff()); + + // Backward diff + layer->Backward(blob_top_vec_, vector(1, true), blob_bottom_vec_); + + // Output of top must match the expected data + EXPECT_EQ(blob_bottom_->count(), blob_top_->count()); + + for (int i = 0; i < blob_top_->count(); ++i) { + EXPECT_FLOAT_EQ(diff_expected[i], blob_bottom_->cpu_diff()[i]); + } + } + + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(ReverseLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ReverseLayerTest, TestForwardAxisZero) { + typedef typename TypeParam::Dtype Dtype; + const Dtype data_in[5 * 2 * 1 * 3] = { + 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30 + }; + + // first axis must be inverted + const Dtype data_expected[5 * 2 * 1 * 3] = { + 25, 26, 27, 28, 29, 30, + 19, 20, 21, 22, 23, 24, + 13, 14, 15, 16, 17, 18, + 7, 8, 9, 10, 11, 12, + 1, 2, 3, 4, 5, 6 + }; + + + this->TestForwardAxis(0, 5, 2, 1, 3, data_in, data_expected); +} + +TYPED_TEST(ReverseLayerTest, TestBackwardAxisZero) { + typedef typename TypeParam::Dtype Dtype; + const Dtype diff_in[5 * 2 * 1 * 3] = { + 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, + 118, 119, 120, 121, 122, 123, + 124, 125, 126, 127, 128, 129 + }; + + // first axis must be inverted + const Dtype diff_expected[5 * 2 * 1 * 3] = { + 124, 125, 126, 127, 128, 129, + 118, 119, 120, 121, 122, 123, + 112, 113, 114, 115, 116, 117, + 106, 107, 108, 109, 110, 111, + 100, 101, 102, 103, 104, 105 + }; + + this->TestBackwardAxis(0, 5, 2, 1, 3, diff_in, diff_expected); +} + +TYPED_TEST(ReverseLayerTest, TestForwardAxisOne) { + typedef typename TypeParam::Dtype Dtype; + const Dtype data_in[5 * 2 * 1 * 3] = { + 1, 2, 3, + 4, 5, 6, + + 7, 8, 9, + 10, 11, 12, + + 13, 14, 15, + 16, 17, 18, + + 19, 20, 21, + 22, 23, 24, + + 25, 26, 27, + 28, 29, 30 + }; + + // second axis must be inverted 
+ const Dtype data_expected[5 * 2 * 1 * 3] = { + 4, 5, 6, + 1, 2, 3, + + 10, 11, 12, + 7, 8, 9, + + 16, 17, 18, + 13, 14, 15, + + 22, 23, 24, + 19, 20, 21, + + 28, 29, 30, + 25, 26, 27 + }; + + this->TestForwardAxis(1, 5, 2, 1, 3, data_in, data_expected); +} + +TYPED_TEST(ReverseLayerTest, TestBackwardAxisOne) { + typedef typename TypeParam::Dtype Dtype; + const Dtype diff_in[5 * 2 * 1 * 3] = { + 100, 101, 102, + 103, 104, 105, + 106, 107, 108, + 109, 110, 111, + 112, 113, 114, + 115, 116, 117, + 118, 119, 120, + 121, 122, 123, + 124, 125, 126, + 127, 128, 129 + }; + + // first axis must be inverted + const Dtype diff_expected[5 * 2 * 1 * 3] = { + 103, 104, 105, + 100, 101, 102, + 109, 110, 111, + 106, 107, 108, + 115, 116, 117, + 112, 113, 114, + 121, 122, 123, + 118, 119, 120, + 127, 128, 129, + 124, 125, 126 + }; + + this->TestBackwardAxis(1, 5, 2, 1, 3, diff_in, diff_expected); +} + +TYPED_TEST(ReverseLayerTest, TestForwardAxisTwo) { + typedef typename TypeParam::Dtype Dtype; + const Dtype data_in[5 * 2 * 1 * 3] = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10, 11, 12, + 13, 14, 15, + 16, 17, 18, + 19, 20, 21, + 22, 23, 24, + 25, 26, 27, + 28, 29, 30 + }; + + // second axis must be inverted + const Dtype data_expected[5 * 2 * 1 * 3] = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10, 11, 12, + 13, 14, 15, + 16, 17, 18, + 19, 20, 21, + 22, 23, 24, + 25, 26, 27, + 28, 29, 30 + }; + + this->TestForwardAxis(2, 5, 2, 1, 3, data_in, data_expected); +} + +TYPED_TEST(ReverseLayerTest, TestBackwardAxisTwo) { + typedef typename TypeParam::Dtype Dtype; + const Dtype diff_in[5 * 2 * 1 * 3] = { + 100, 101, 102, + 103, 104, 105, + 106, 107, 108, + 109, 110, 111, + 112, 113, 114, + 115, 116, 117, + 118, 119, 120, + 121, 122, 123, + 124, 125, 126, + 127, 128, 129 + }; + + // first axis must be inverted + const Dtype diff_expected[5 * 2 * 1 * 3] = { + 100, 101, 102, + 103, 104, 105, + 106, 107, 108, + 109, 110, 111, + 112, 113, 114, + 115, 116, 117, + 118, 119, 120, + 121, 122, 123, + 124, 125, 126, + 127, 128, 129 + }; + + this->TestBackwardAxis(2, 5, 2, 1, 3, diff_in, diff_expected); +} + +TYPED_TEST(ReverseLayerTest, TestForwardAxisThree) { + typedef typename TypeParam::Dtype Dtype; + const Dtype data_in[5 * 2 * 1 * 3] = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10, 11, 12, + 13, 14, 15, + 16, 17, 18, + 19, 20, 21, + 22, 23, 24, + 25, 26, 27, + 28, 29, 30 + }; + + // second axis must be inverted + const Dtype data_expected[5 * 2 * 1 * 3] = { + 3, 2, 1, + 6, 5, 4, + 9, 8, 7, + 12, 11, 10, + 15, 14, 13, + 18, 17, 16, + 21, 20, 19, + 24, 23, 22, + 27, 26, 25, + 30, 29, 28 + }; + + this->TestForwardAxis(3, 5, 2, 1, 3, data_in, data_expected); +} + +TYPED_TEST(ReverseLayerTest, TestBackwardAxisThree) { + typedef typename TypeParam::Dtype Dtype; + const Dtype diff_in[5 * 2 * 1 * 3] = { + 100, 101, 102, + 103, 104, 105, + 106, 107, 108, + 109, 110, 111, + 112, 113, 114, + 115, 116, 117, + 118, 119, 120, + 121, 122, 123, + 124, 125, 126, + 127, 128, 129 + }; + + // first axis must be inverted + const Dtype diff_expected[5 * 2 * 1 * 3] = { + 102, 101, 100, + 105, 104, 103, + 108, 107, 106, + 111, 110, 109, + 114, 113, 112, + 117, 116, 115, + 120, 119, 118, + 123, 122, 121, + 126, 125, 124, + 129, 128, 127 + }; + + this->TestBackwardAxis(3, 5, 2, 1, 3, diff_in, diff_expected); +} + +} // namespace caffe