Skip to content

Commit

Permalink
Add SequenceLSTMOptions to schema to decouple the sequential Op from …
Browse files Browse the repository at this point in the history
…the LSTM.

PiperOrigin-RevId: 216066634
  • Loading branch information
tensorflower-gardener committed Oct 7, 2018
1 parent e93a189 commit 7fa6a6b
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 12 deletions.
7 changes: 7 additions & 0 deletions tensorflow/contrib/lite/c/builtin_op_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ typedef struct {
TfLiteLSTMKernelType kernel_type;
} TfLiteLSTMParams;

typedef struct {
// Parameters for the UNIDIRECTIONAL_SEQUENCE_LSTM kernel. Mirrors the fields
// of UnidirectionalSequenceLSTMOptions in the flatbuffer schema, decoupling
// the sequential op from TfLiteLSTMParams (which also carries kernel_type).
TfLiteFusedActivation activation;
// Cell state clip threshold; 0.0 means no clipping (schema default).
float cell_clip;
// Projection clip threshold; 0.0 means no clipping (schema default).
float proj_clip;
} TfLiteUnidirectionalSequenceLSTMParams;

typedef struct {
// Parameters for the LSTM kernel.
TfLiteFusedActivation activation;
Expand Down
15 changes: 14 additions & 1 deletion tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
Expand All @@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
// The sequential op now has its own options table
// (UnidirectionalSequenceLSTMOptions), so it no longer falls through to the
// LSTM case above and no longer shares TfLiteLSTMParams.
auto* params =
allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>();
if (auto* seq_lstm_params =
op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) {
params->activation =
parse_activation(seq_lstm_params->fused_activation_function());
params->cell_clip = seq_lstm_params->cell_clip();
params->proj_clip = seq_lstm_params->proj_clip();
}
// NOTE(review): if builtin options are absent, the fields keep whatever
// AllocatePOD produced — confirm the allocator zero-initializes.
*builtin_data = reinterpret_cast<void*>(params);
break;
}

case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
auto params =
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
Expand Down
14 changes: 11 additions & 3 deletions tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
const auto* params =
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);

const TfLiteTensor* input_to_input_weights =
Expand Down Expand Up @@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

// Copy out the LSTM specific params so they can be passed in the function.
TfLiteLSTMParams lstm_params;
lstm_params.activation = params->activation;
lstm_params.cell_clip = params->cell_clip;
lstm_params.proj_clip = params->proj_clip;

switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
return lstm_eval::EvalFloat(
Expand All @@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output);
}
Expand All @@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {

output_ = AddOutput(TensorType_FLOAT32);

SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
SetBuiltinOp(
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
CreateUnidirectionalSequenceLSTMOptions(
builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
.Union());
BuildInterpreter(input_shapes);
}

Expand Down
8 changes: 8 additions & 0 deletions tensorflow/contrib/lite/schema/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ union BuiltinOptions {
FillOptions,
BidirectionalSequenceLSTMOptions,
BidirectionalSequenceRNNOptions,
UnidirectionalSequenceLSTMOptions,
}

enum Padding : byte { SAME, VALID }
Expand Down Expand Up @@ -394,6 +395,13 @@ table LSTMOptions {
kernel_type: LSTMKernelType = FULL;
}

// An implementation of TensorFlow dynamic_rnn with LSTMCell.
// Mirrors LSTMOptions minus kernel_type, so the sequential op's options can
// evolve independently of the single-step LSTM op.
table UnidirectionalSequenceLSTMOptions {
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
}

table BidirectionalSequenceLSTMOptions {
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
Expand Down
162 changes: 159 additions & 3 deletions tensorflow/contrib/lite/schema/schema_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT;
struct LSTMOptions;
struct LSTMOptionsT;

struct UnidirectionalSequenceLSTMOptions;
struct UnidirectionalSequenceLSTMOptionsT;

struct BidirectionalSequenceLSTMOptions;
struct BidirectionalSequenceLSTMOptionsT;

Expand Down Expand Up @@ -681,11 +684,12 @@ enum BuiltinOptions {
BuiltinOptions_FillOptions = 68,
BuiltinOptions_BidirectionalSequenceLSTMOptions = 69,
BuiltinOptions_BidirectionalSequenceRNNOptions = 70,
BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions
BuiltinOptions_MAX = BuiltinOptions_UnidirectionalSequenceLSTMOptions
};

inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[72] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
Expand Down Expand Up @@ -757,7 +761,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
BuiltinOptions_ZerosLikeOptions,
BuiltinOptions_FillOptions,
BuiltinOptions_BidirectionalSequenceLSTMOptions,
BuiltinOptions_BidirectionalSequenceRNNOptions
BuiltinOptions_BidirectionalSequenceRNNOptions,
BuiltinOptions_UnidirectionalSequenceLSTMOptions
};
return values;
}
Expand Down Expand Up @@ -835,6 +840,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"FillOptions",
"BidirectionalSequenceLSTMOptions",
"BidirectionalSequenceRNNOptions",
"UnidirectionalSequenceLSTMOptions",
nullptr
};
return names;
Expand Down Expand Up @@ -1129,6 +1135,10 @@ template<> struct BuiltinOptionsTraits<BidirectionalSequenceRNNOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions;
};

template<> struct BuiltinOptionsTraits<UnidirectionalSequenceLSTMOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
};

struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
Expand Down Expand Up @@ -1720,6 +1730,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
}
UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() {
return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ?
reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
}
const UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const {
return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ?
reinterpret_cast<const UnidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
}
};

bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
Expand Down Expand Up @@ -3469,6 +3487,84 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(

flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);

// Native (object-API) mirror of the UnidirectionalSequenceLSTMOptions table.
// Generated code: edit schema.fbs and regenerate rather than hand-modifying.
struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
typedef UnidirectionalSequenceLSTMOptions TableType;
ActivationFunctionType fused_activation_function;
float cell_clip;  // 0.0f means no clipping.
float proj_clip;  // 0.0f means no clipping.
UnidirectionalSequenceLSTMOptionsT()
: fused_activation_function(ActivationFunctionType_NONE),
cell_clip(0.0f),
proj_clip(0.0f) {
}
};

// In-buffer accessor for the UnidirectionalSequenceLSTMOptions table.
struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef UnidirectionalSequenceLSTMOptionsT NativeTableType;
// Vtable offsets for each field.
enum {
VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8
};
// Stored fused activation function; defaults to NONE (0) when absent.
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
// Cell clip threshold; defaults to 0.0f (no clipping) when absent.
float cell_clip() const {
return GetField<float>(VT_CELL_CLIP, 0.0f);
}
// Projection clip threshold; defaults to 0.0f (no clipping) when absent.
float proj_clip() const {
return GetField<float>(VT_PROJ_CLIP, 0.0f);
}
// Checks the table header and each scalar field against buffer bounds.
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
verifier.EndTable();
}
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};

// Incremental builder for a UnidirectionalSequenceLSTMOptions table.
struct UnidirectionalSequenceLSTMOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
// Fields equal to the schema default (the last AddElement argument) are
// omitted from the serialized table.
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
void add_cell_clip(float cell_clip) {
fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
}
void add_proj_clip(float proj_clip) {
fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
}
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
// Declared but intentionally not defined: the builder is non-assignable.
UnidirectionalSequenceLSTMOptionsBuilder &operator=(const UnidirectionalSequenceLSTMOptionsBuilder &);
flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<UnidirectionalSequenceLSTMOptions>(end);
return o;
}
};

// Field-wise factory; defaults match the schema (NONE activation, no clipping).
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(
flatbuffers::FlatBufferBuilder &_fbb,
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
float cell_clip = 0.0f,
float proj_clip = 0.0f) {
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
// Fields are added floats-first, then the int8 activation, matching the
// order the flatbuffers generator emits — do not reorder.
builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip);
builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish();
}

flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);

struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
typedef BidirectionalSequenceLSTMOptions TableType;
ActivationFunctionType fused_activation_function;
Expand Down Expand Up @@ -6488,6 +6584,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const {
return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast<const BidirectionalSequenceRNNOptions *>(builtin_options()) : nullptr;
}
const UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const {
return builtin_options_type() == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast<const UnidirectionalSequenceLSTMOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
Expand Down Expand Up @@ -6799,6 +6898,10 @@ template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_optio
return builtin_options_as_BidirectionalSequenceRNNOptions();
}

template<> inline const UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as<UnidirectionalSequenceLSTMOptions>() const {
return builtin_options_as_UnidirectionalSequenceLSTMOptions();
}

struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
Expand Down Expand Up @@ -7809,6 +7912,38 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
_kernel_type);
}

// Unpacks this table into a freshly allocated native object.
// Ownership of the returned pointer transfers to the caller.
inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto *native = new UnidirectionalSequenceLSTMOptionsT();
  UnPackTo(native, _resolver);
  return native;
}

// Copies every field of this flatbuffer table into the given native object.
inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  _o->fused_activation_function = fused_activation_function();
  _o->cell_clip = cell_clip();
  _o->proj_clip = proj_clip();
}

// Serializes a native options object back into a flatbuffer table by
// forwarding to the object-API Create overload below.
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateUnidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher);
}

// Object-API overload: serializes a native UnidirectionalSequenceLSTMOptionsT
// by forwarding each field straight to the field-wise factory above.
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // Kept for parity with other generated Create functions; unused here since
  // this table has no vector fields.
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  return tflite::CreateUnidirectionalSequenceLSTMOptions(
      _fbb,
      _o->fused_activation_function,
      _o->cell_clip,
      _o->proj_clip);
}

inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BidirectionalSequenceLSTMOptionsT();
UnPackTo(_o, _resolver);
Expand Down Expand Up @@ -9620,6 +9755,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
Expand Down Expand Up @@ -9918,6 +10057,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
Expand Down Expand Up @@ -10204,6 +10347,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value);
return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptionsT *>(value);
return CreateUnidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
Expand Down Expand Up @@ -10490,6 +10637,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(u.value));
break;
}
case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
value = new UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(u.value));
break;
}
default:
break;
}
Expand Down Expand Up @@ -10847,6 +10998,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
auto ptr = reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;
Expand Down

0 comments on commit 7fa6a6b

Please sign in to comment.