Skip to content

Commit

Permalink
Add FLOP annotation functions to operator schema
Browse files Browse the repository at this point in the history
Summary: Basic FLOP annotation functionality added to operator schema.

Reviewed By: dzhulgakov

Differential Revision: D5114086

fbshipit-source-id: 8a15d45dee744fbdceaed3773d70fb69a5cf0d24
  • Loading branch information
bwasti authored and facebook-github-bot committed May 30, 2017
1 parent acb2ad1 commit 0deec5b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
5 changes: 5 additions & 0 deletions caffe2/core/operator_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
});
}

OpSchema& OpSchema::CostInferenceFunction(CostInferenceFunctionType function) {
cost_inference_function_ = function;
return *this;
}

OpSchema& OpSchema::SetDoc(const string& doc) {
doc_ = doc;
return *this;
Expand Down
33 changes: 33 additions & 0 deletions caffe2/core/operator_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/registry.h"
#include "caffe2/proto/caffe2.pb.h"

Expand Down Expand Up @@ -159,6 +160,33 @@ class OpSchema {
return tensor_inference_function_(def, input_type_shape);
}

/*
* @brief A struct to store various cost information about
* an operator such as FLOPs and total memory use.
*/
struct Cost {
size_t flops; // Floating point operations.
size_t bytes_moved; // Total memory used.
};
/**
* @brief Registers a function that takes in an OperatorDef
* and a series of input shapes and returns the total "cost"
* required to run the operator via struct by value.
*/
typedef std::function<
struct Cost(const OperatorDef&, const vector<TensorShape>&)>
CostInferenceFunctionType;

/**
* @brief Register the Cost inference function.
*/
OpSchema& CostInferenceFunction(CostInferenceFunctionType function);
inline struct Cost InferCost(
const OperatorDef& def,
const vector<TensorShape>& input_tensor_shape) const {
return cost_inference_function_(def, input_tensor_shape);
}

// Functions to do documentation for the operator schema.
OpSchema& SetDoc(const string& doc);
OpSchema& Arg(const char* name, const char* description);
Expand Down Expand Up @@ -226,6 +254,11 @@ class OpSchema {
}
return out;
};
CostInferenceFunctionType cost_inference_function_ =
[](const OperatorDef& def, const vector<TensorShape>&) {
CAFFE_THROW("No cost inference function registered.");
return Cost();
};
};

/**
Expand Down
31 changes: 30 additions & 1 deletion caffe2/core/operator_schema_test.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/utils/proto_utils.h"

#include <gtest/gtest.h>
Expand Down Expand Up @@ -208,4 +209,32 @@ TEST(OperatorSchemaTest, TestCastSchema) {
EXPECT_EQ(out[0].dims_size(), 0);
}

OPERATOR_SCHEMA(OpSchemaCostInference)
.NumInputs(2)
.NumOutputs(2)
.CostInferenceFunction(
[](const OperatorDef& def, const vector<TensorShape>& inputs) {
struct OpSchema::Cost c;
c.flops =
2 * inputs[0].dims(0) * inputs[0].dims(1) * inputs[1].dims(1);
return c;
});

TEST(OperatorSchemaTest, TestCostInference) {
const OpSchema* schema = OpSchemaRegistry::Schema("OpSchemaCostInference");
if (!schema) {
return;
}
OperatorDef def = CreateOperatorDef(
"OpSchemaCostInference", "", vector<string>{"in"}, vector<string>{"out"});
vector<TensorShape> shapes(2);
shapes[0].set_data_type(TensorProto::FLOAT);
shapes[0].add_dims(10);
shapes[0].add_dims(10);
shapes[1].set_data_type(TensorProto::FLOAT);
shapes[1].add_dims(10);
shapes[1].add_dims(10);
EXPECT_EQ(2000, schema->InferCost(def, shapes).flops);
}

} // namespace caffe2

0 comments on commit 0deec5b

Please sign in to comment.