Skip to content

Commit

Permalink
improve flop estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Aug 11, 2020
1 parent 3ee17d9 commit 522c466
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
28 changes: 18 additions & 10 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,17 +536,25 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
double ret = 0;
for (const auto& op : ops) {
if (auto pop = op.as<te::ComputeOpNode>()) {
double num_element = AxisLengthProd(pop->axis);
if (num_element == -1) {
fail_ = true;
break;
}
cur_type_code_ = pop->output_dtype(0).code();
double op_per_element = 0;
for (const auto& x : pop->body) {
op_per_element += VisitExpr(x);
if (pop->attrs.count("FLOP")) {
// Use user-provided FLOP
auto pint = pop->attrs["FLOP"].as<IntImmNode>();
CHECK(pint != nullptr);
ret += pint->value;
} else {
// Estimate by parsing the compute body
double num_element = AxisLengthProd(pop->axis);
if (num_element == -1) {
fail_ = true;
break;
}
cur_type_code_ = pop->output_dtype(0).code();
double op_per_element = 0;
for (const auto& x : pop->body) {
op_per_element += VisitExpr(x);
}
ret += num_element * op_per_element;
}
ret += num_element * op_per_element;
} else if (op->IsInstance<te::PlaceholderOpNode>()) {
{} // do nothing
} else {
Expand Down
17 changes: 4 additions & 13 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,7 @@ enum class AnnotationPosType : int {
};

// Buffer access type
enum class BufferAccessType : int {
kRead = 0,
kWrite = 1,
kReadWrite = 2,
kUnknownRW = 3
};
enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 };

// Accesses to a buffer
struct BufferAccess {
Expand All @@ -89,11 +84,7 @@ struct BufferAccess {
};

// Data reuse type
enum class ReuseType : int {
kLoopMultipleRead = 0,
kSerialMultipleReadWrite = 1,
kNoReuse = 2
};
enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 };

// Feature for an access of a buffer
struct BufferAccessFeature {
Expand Down Expand Up @@ -1514,8 +1505,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile")
std::vector<float> normalized_throughputs;
std::vector<int> task_ids;

GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features, &normalized_throughputs,
&task_ids);
GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features,
&normalized_throughputs, &task_ids);

std::vector<char> byte_data;
*ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
Expand Down
10 changes: 7 additions & 3 deletions tests/python/unittest/test_auto_scheduler_compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,17 @@ def test_estimate_flop():

D = topi.nn.relu(C)
dag = auto_scheduler.ComputeDAG([A, B, D])
assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5
assert abs(dag.flop_ct - (2 * N ** 3 + N * N)) < 0.5

# should not count the comparison operations in padding
D = topi.nn.pad(C, [1, 1])
dag = auto_scheduler.ComputeDAG([A, B, D])
E = topi.nn.pad(C, [1, 1])
dag = auto_scheduler.ComputeDAG([A, B, E])
assert abs(dag.flop_ct - 2 * N ** 3) < 0.5

F = te.compute((N, N), lambda i, j: E[i,j], name='F', attrs={"FLOP": 1234})
dag = auto_scheduler.ComputeDAG([A, B, F])
assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5


if __name__ == "__main__":
test_apply_steps()
Expand Down

0 comments on commit 522c466

Please sign in to comment.