diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h index 44daf7adaa0e76..1e65c3cee27798 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data.h +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -186,6 +186,13 @@ typedef struct { TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; +typedef struct { + // Parameters for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; +} TfLiteUnidirectionalSequenceLSTMParams; + typedef struct { // Parameters for the LSTM kernel. TfLiteFusedActivation activation; diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index eac7db9a88d2ad..b092e5ee547805 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { auto params = allocator->AllocatePOD(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { @@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto* params = + allocator->AllocatePOD(); + if (auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + } + *builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { auto params = allocator->AllocatePOD(); diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index ec9cf38b831c22..89d57e45993af9 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = + reinterpret_cast( + node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_to_input_weights = @@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Copy out the LSTM specific params so they can be passed in the function. + TfLiteLSTMParams lstm_params; + lstm_params.activation = params->activation; + lstm_params.cell_clip = params->cell_clip; + lstm_params.proj_clip = params->proj_clip; + switch (input_to_output_weights->type) { case kTfLiteFloat32: { return lstm_eval::EvalFloat( @@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, params, /*forward_sequence=*/true, + projection_bias, &lstm_params, /*forward_sequence=*/true, /*output_offset=*/0, scratch_buffer, activation_state, cell_state, output); } @@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, params, /*forward_sequence=*/true, + projection_bias, &lstm_params, /*forward_sequence=*/true, /*output_offset=*/0, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc index cd3aac053262c3..c97b0fdd612497 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -110,11 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOptions_LSTMOptions, - CreateLSTMOptions(builder_, ActivationFunctionType_TANH, - cell_clip, proj_clip) - .Union()); + SetBuiltinOp( + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_UnidirectionalSequenceLSTMOptions, + CreateUnidirectionalSequenceLSTMOptions( + builder_, ActivationFunctionType_TANH, cell_clip, proj_clip) + .Union()); BuildInterpreter(input_shapes); } diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index ff8430827c7849..cb7a2827433083 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -250,6 +250,7 @@ union BuiltinOptions { FillOptions, BidirectionalSequenceLSTMOptions, BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, } enum Padding : byte { SAME, VALID } @@ -394,6 +395,13 @@ table LSTMOptions { kernel_type: LSTMKernelType = FULL; } +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + table BidirectionalSequenceLSTMOptions { fused_activation_function:ActivationFunctionType; cell_clip: float; // Optional, 0.0 means no clipping diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index f3cb113c9c58f8..e7b7a59def3e4f 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT; struct LSTMOptions; struct LSTMOptionsT; +struct UnidirectionalSequenceLSTMOptions; +struct UnidirectionalSequenceLSTMOptionsT; + struct BidirectionalSequenceLSTMOptions; struct BidirectionalSequenceLSTMOptionsT; @@ -681,11 +684,12 @@ enum BuiltinOptions { BuiltinOptions_FillOptions = 68, BuiltinOptions_BidirectionalSequenceLSTMOptions = 69, BuiltinOptions_BidirectionalSequenceRNNOptions = 70, + BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions + BuiltinOptions_MAX = BuiltinOptions_UnidirectionalSequenceLSTMOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[72] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -757,7 +761,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] { BuiltinOptions_ZerosLikeOptions, BuiltinOptions_FillOptions, BuiltinOptions_BidirectionalSequenceLSTMOptions, - BuiltinOptions_BidirectionalSequenceRNNOptions + BuiltinOptions_BidirectionalSequenceRNNOptions, + BuiltinOptions_UnidirectionalSequenceLSTMOptions }; return values; } @@ -835,6 +840,7 @@ inline const char * const *EnumNamesBuiltinOptions() { "FillOptions", "BidirectionalSequenceLSTMOptions", "BidirectionalSequenceRNNOptions", + "UnidirectionalSequenceLSTMOptions", nullptr }; return names; @@ -1129,6 +1135,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1720,6 +1730,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? reinterpret_cast(value) : nullptr; } + UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() { + return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? + reinterpret_cast(value) : nullptr; + } + const UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const { + return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -3469,6 +3487,84 @@ inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { + typedef UnidirectionalSequenceLSTMOptions TableType; + ActivationFunctionType fused_activation_function; + float cell_clip; + float proj_clip; + UnidirectionalSequenceLSTMOptionsT() + : fused_activation_function(ActivationFunctionType_NONE), + cell_clip(0.0f), + proj_clip(0.0f) { + } +}; + +struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UnidirectionalSequenceLSTMOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_CELL_CLIP) && + VerifyField(verifier, VT_PROJ_CLIP) && + verifier.EndTable(); + } + UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnidirectionalSequenceLSTMOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UnidirectionalSequenceLSTMOptionsBuilder &operator=(const UnidirectionalSequenceLSTMOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f) { + UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { typedef BidirectionalSequenceLSTMOptions TableType; ActivationFunctionType fused_activation_function; @@ -6488,6 +6584,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const { return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast(builtin_options()) : nullptr; } + const UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const { + return builtin_options_type() == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -6799,6 +6898,10 @@ template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_optio return builtin_options_as_BidirectionalSequenceRNNOptions(); } +template<> inline const UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnidirectionalSequenceLSTMOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -7809,6 +7912,38 @@ inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBuffe _kernel_type); } +inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new UnidirectionalSequenceLSTMOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = cell_clip(); _o->cell_clip = _e; }; + { auto _e = proj_clip(); _o->proj_clip = _e; }; +} + +inline flatbuffers::Offset UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateUnidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + return tflite::CreateUnidirectionalSequenceLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip); +} + inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BidirectionalSequenceLSTMOptionsT(); UnPackTo(_o, _resolver); @@ -9620,6 +9755,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -9918,6 +10057,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -10204,6 +10347,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -10490,6 +10637,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + value = new UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -10847,6 +10998,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr;