Fix cooperative fetching (#17)

junrushao authored Jan 22, 2022
1 parent e41b5b2 commit 3e4a30e
Showing 9 changed files with 193 additions and 129 deletions.
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -141,7 +141,7 @@ class ScheduleRule : public runtime::ObjectRef {
  * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
  * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
  * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
- * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
+ * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
  * NullOpt means disable vectorization
  * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
  * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
@@ -151,7 +151,7 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Array<String>> tile_binds,           //
       bool use_tensor_core,                         //
       Optional<Integer> max_innermost_factor,       //
-      Optional<Integer> vector_load_max_len,        //
+      Optional<Array<Integer>> vector_load_lens,    //
       Optional<Map<String, ObjectRef>> reuse_read,  //
       Optional<Map<String, ObjectRef>> reuse_write);
   /*!
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -57,7 +57,7 @@ class MultiLevelTiling(ScheduleRule):
         Whether to apply tensor core wmma intrinsic for the computation
     max_innermost_factor : Optional[int]
         The maximum size of the innermost factor. None means no limit
-    vector_load_max_len : Optional[int]
+    vector_load_lens : Optional[List[int]]
         The length of vector lane in vectorized cooperative fetching.
         None means disable vectorization
     reuse_read : Optional[ReuseType]
@@ -72,7 +72,7 @@ def __init__(
         tile_binds: Optional[List[str]] = None,
         use_tensor_core: bool = False,
         max_innermost_factor: Optional[int] = None,
-        vector_load_max_len: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
         reuse_read: Optional[ReuseType] = None,
         reuse_write: Optional[ReuseType] = None,
     ) -> None:
@@ -82,7 +82,7 @@ def __init__(
             tile_binds,
             use_tensor_core,
             max_innermost_factor,
-            vector_load_max_len,
+            vector_load_lens,
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
         )
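
For reference, a minimal sketch of how the updated Python API is called after this change, assuming MultiLevelTiling and ReuseType are both importable from tvm.meta_schedule.schedule_rule as the files above suggest; the reuse scopes shown are illustrative, not taken from this diff:

    from tvm.meta_schedule.schedule_rule import MultiLevelTiling, ReuseType

    rule = MultiLevelTiling(
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
        max_innermost_factor=64,
        vector_load_lens=[1, 2, 3, 4],  # explicit candidates; one is sampled per cache read
        reuse_read=ReuseType(req="must", levels=[4], scope="shared"),
        reuse_write=ReuseType(req="must", levels=[3], scope="local"),
    )
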
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -111,7 +111,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
             structure="SSRSRS",
             tile_binds=None,
             max_innermost_factor=64,
-            vector_load_max_len=None,
+            vector_load_lens=None,
             reuse_read=None,
             reuse_write=ReuseType(
                 req="may",
@@ -124,7 +124,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
             structure="SSSRRSRS",
             tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
             max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
             reuse_read=ReuseType(
                 req="must",
                 levels=[4],
@@ -147,7 +147,7 @@ def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule:
             tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
             use_tensor_core=True,
             max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
             reuse_read=ReuseType(
                 req="must",
                 levels=[4],
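
The switch from vector_load_max_len=4 to vector_load_lens=[1, 2, 3, 4] is a semantic change, not just a rename: the old rule obtained the vector length as the inner factor of a SamplePerfectTile split, so only lengths dividing the fused loop extent could be chosen, while the new rule samples one of the listed lengths directly and leaves realizing the split to the later rewrite of the annotation. A rough trace-level sketch of the two behaviors, where sch and fused are hypothetical stand-ins for a schedule and a fused cache-read loop:

    # Before: the vector length was the inner factor of a perfect two-way
    # split, at most 4, and therefore always a divisor of the fused extent.
    _, vec_len = sch.sample_perfect_tile(fused, n=2, max_innermost_factor=4)

    # After: the length is drawn uniformly from the explicit candidates,
    # independent of the loop extent.
    vec_len = sch.sample_categorical(
        candidates=[1, 2, 3, 4],
        probs=[0.25, 0.25, 0.25, 0.25],
    )
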
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/tune.py
@@ -101,7 +101,7 @@ def _sch_rules() -> List[ScheduleRule]:
             structure="SSRSRS",
             tile_binds=None,
             max_innermost_factor=64,
-            vector_load_max_len=None,
+            vector_load_lens=None,
             reuse_read=None,
             reuse_write=M.ReuseType(
                 req="may",
@@ -158,7 +158,7 @@ def _sch_rules() -> List[ScheduleRule]:
             structure="SSSRRSRS",
             tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
             max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
             reuse_read=M.ReuseType(
                 req="must",
                 levels=[4],
108 changes: 84 additions & 24 deletions src/meta_schedule/mutator/mutate_tile_size.cc
@@ -33,7 +33,7 @@ using tir::Trace;
  * \param decision The decision of Sample-Perfect-Tile
  * \return The result of downcast
  */
-std::vector<int64_t> DowncastDecision(const ObjectRef& decision) {
+std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
   const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
   return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
 }
@@ -73,34 +73,62 @@ class MutateTileSizeNode : public MutatorNode {
  * \param decision The decision selected
  * \return Whether a decision is found
  */
-bool FindSamplePerfectTile(const Trace& trace, TRandState* rand_state, Instruction* inst,
-                           std::vector<int64_t>* decision) {
+void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
+                           std::vector<std::vector<int64_t>>* decision) {
   static const InstructionKind& inst_sample_perfect_tile =
       InstructionKind::Get("SamplePerfectTile");
-  std::vector<Instruction> instructions;
-  std::vector<std::vector<int64_t>> decisions;
+  std::vector<Instruction>& instructions = *inst;
+  std::vector<std::vector<int64_t>>& decisions = *decision;
   instructions.reserve(trace->decisions.size());
   decisions.reserve(trace->decisions.size());
   for (const auto& kv : trace->decisions) {
     const Instruction& inst = kv.first;
     const ObjectRef& decision = kv.second;
-    if (!inst->kind.same_as(inst_sample_perfect_tile)) {
-      continue;
+    if (inst->kind.same_as(inst_sample_perfect_tile)) {
+      std::vector<int64_t> tiles = DowncastTilingDecision(decision);
+      if (tiles.size() >= 2 && Product(tiles) >= 2) {
+        instructions.push_back(inst);
+        decisions.push_back(tiles);
+      }
     }
-    std::vector<int64_t> tiles = DowncastDecision(decision);
-    if (tiles.size() >= 2 && Product(tiles) >= 2) {
-      instructions.push_back(inst);
-      decisions.push_back(tiles);
-    }
   }
+}
+
+void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
+                         std::vector<int64_t>* decision) {
+  static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
+  static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+  std::vector<Instruction>& instructions = *inst;
+  std::vector<int64_t>& decisions = *decision;
+  std::unordered_set<const Object*> annotated;
+  instructions.reserve(trace->decisions.size());
+  decisions.reserve(trace->decisions.size());
+  annotated.reserve(trace->decisions.size());
+  // Find annotation with `meta_schedule_cooperative_fetch`
+  for (const Instruction& inst : trace->insts) {
+    if (inst->kind.same_as(inst_annotate)) {
+      ICHECK_EQ(inst->attrs.size(), 1);
+      ICHECK_EQ(inst->inputs.size(), 2);
+      if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
+        const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
+        ICHECK(ann_val);
+        annotated.insert(ann_val);
+      }
+    }
+  }
-  int n = instructions.size();
-  if (n > 0) {
-    int i = tir::SampleInt(rand_state, 0, n);
-    *inst = instructions[i];
-    *decision = decisions[i];
-    return true;
+  // Find sampling instruction that generates the annotation
+  for (const auto& kv : trace->decisions) {
+    const Instruction& inst = kv.first;
+    const ObjectRef& decision = kv.second;
+    if (inst->kind.same_as(inst_sample_categorical)) {
+      ICHECK_EQ(inst->outputs.size(), 1);
+      if (annotated.count(inst->outputs[0].get())) {
+        const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
+        instructions.push_back(inst);
+        decisions.push_back(d->value);
+      }
+    }
   }
-  return false;
 }
 
 struct FactorMemo {
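
For intuition, a plain-Python mirror of the two-pass matching in FindSampleVectorize above; the Inst records and string kinds here are hypothetical stand-ins, not real TVM trace objects:

    def find_sample_vectorize(insts, decisions):
        # Pass 1: collect the random variables annotated as cooperative-fetch
        # vector lengths.
        annotated = {
            inst.inputs[1]
            for inst in insts
            if inst.kind == "Annotate"
            and inst.attrs[0] == "meta_schedule.cooperative_fetch"
        }
        # Pass 2: keep the SampleCategorical instructions whose output is one
        # of those variables, together with their recorded integer decisions.
        return [
            (inst, decision)
            for inst, decision in decisions.items()
            if inst.kind == "SampleCategorical" and inst.outputs[0] in annotated
        ]
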
@@ -146,12 +174,8 @@ struct FactorMemo {
   std::mutex mutex_;
 };
 
-Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
-  Instruction inst;
-  std::vector<int64_t> tiles;
-  if (!FindSamplePerfectTile(trace, rand_state, &inst, &tiles)) {
-    return NullOpt;
-  }
+Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
+                                     std::vector<int64_t> tiles, TRandState* rand_state) {
   int n_splits = tiles.size();
   // Step 1. Choose two loops, `x` and `y`
   int x, y;
@@ -194,6 +218,42 @@ Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
   }
 }
 
+Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
+                                      int64_t original_decision, TRandState* rand_state) {
+  ICHECK_EQ(inst->attrs.size(), 2);
+  std::vector<double> probs =
+      support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
+  probs.erase(probs.begin() + original_decision);
+  int result = tir::MakeMultinomialSampler(rand_state, probs)();
+  if (result >= original_decision) {
+    result += 1;
+  }
+  return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
+}
+
+Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
+  std::vector<Instruction> sample_perfect_tile_insts;
+  std::vector<Instruction> sample_vectorize_insts;
+  std::vector<std::vector<int64_t>> sample_perfect_tile_tiles;
+  std::vector<int64_t> sample_vectorize_decisions;
+  FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles);
+  FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions);
+  int size_a = sample_perfect_tile_insts.size();
+  int size_b = sample_vectorize_insts.size();
+  if (size_a == 0 && size_b == 0) {
+    return NullOpt;
+  }
+  int n = tir::SampleInt(rand_state, 0, size_a + size_b);
+  if (n < size_a) {
+    return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n],
+                                rand_state);
+  } else {
+    n -= size_a;
+    return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n],
+                                 rand_state);
+  }
+}
+
 Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); }
 
 TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
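
The mutator change has two parts: MutateSampleVectorize re-samples a SampleCategorical decision while excluding the original choice, and the new Apply draws uniformly across all candidate instructions, so each mutation kind is picked in proportion to how many candidates it contributes. A plain-Python sketch of both ideas; these helpers are hypothetical and not part of the TVM API:

    import random

    def resample_excluding(probs, original):
        # Drop the original decision, sample from the remaining weights, then
        # shift indices at or above the removed slot up by one, so the result
        # is guaranteed to differ from `original`.
        rest = probs[:original] + probs[original + 1:]
        result = random.choices(range(len(rest)), weights=rest, k=1)[0]
        return result + 1 if result >= original else result

    def apply_mutation(tile_cands, vec_cands):
        # One uniform draw over all candidates decides both which mutation
        # kind runs and which instruction it mutates.
        if not tile_cands and not vec_cands:
            return None
        n = random.randrange(len(tile_cands) + len(vec_cands))
        if n < len(tile_cands):
            return ("tile", tile_cands[n])
        return ("vectorize", vec_cands[n - len(tile_cands)])

    # With uniform probs over 4 candidates and original decision 2, the
    # mutated decision is uniform over {0, 1, 3}.
    print(resample_excluding([0.25, 0.25, 0.25, 0.25], 2))
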
24 changes: 14 additions & 10 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -322,7 +322,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   /*! \brief The maximum size of the innermost factor */
   int max_innermost_factor;
   /*! \brief The length of vector lane in vectorized cooperative fetching */
-  int vector_load_max_len;
+  std::vector<int> vector_load_lens;
   /*! \brief Data reuse configuration for reading */
   ReuseConfig reuse_read_;
   /*! \brief Data reuse configuration for writing */
@@ -337,7 +337,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
     v->Visit("tile_binds", &tile_binds);
     v->Visit("use_tensor_core", &use_tensor_core);
     v->Visit("max_innermost_factor", &max_innermost_factor);
-    v->Visit("vector_load_max_len", &vector_load_max_len);
+    // `vector_load_lens` is not visited
     // `reuse_read_` is not visited
     // `reuse_write_` is not visited
     // `s_indices_` is not visited
@@ -491,12 +491,14 @@ inline std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const
     LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim,  //
                                            buffer_loops.end()});
     // Annotate cooperative fetching
-    if (vector_load_max_len > 0) {
-      // cooperative fetch + vectorized loading
-      // Split into inner and outer, vectorize the inner loop
-      Array<ExprRV> factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len);
-      // Add cooperative fetching
-      sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, factors[1]);
+    if (!vector_load_lens.empty()) {
+      int n = vector_load_lens.size();
+      double prob = 1.0 / n;
+      ExprRV vector_load_len =
+          sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
+                                 Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
+      sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
+                    vector_load_len);
     }
   }
   State new_state = state;
@@ -545,7 +547,7 @@ inline std::vector<State> MultiLevelTilingNode::FuseWriteReuse(State state) const
 ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
                                             bool use_tensor_core,
                                             Optional<Integer> max_innermost_factor,
-                                            Optional<Integer> vector_load_max_len,
+                                            Optional<Array<Integer>> vector_load_lens,
                                             Optional<Map<String, ObjectRef>> reuse_read,
                                             Optional<Map<String, ObjectRef>> reuse_write) {
   ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
@@ -561,7 +563,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
     tir::TensorIntrin::Get("wmma_fill");
   }
   n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
-  n->vector_load_max_len = vector_load_max_len.value_or(Integer(-1))->value;
+  n->vector_load_lens = vector_load_lens.defined()
+                            ? support::AsVector<Integer, int>(vector_load_lens.value())
+                            : std::vector<int>();
   n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
   n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
   for (int i = 0, len = structure.size(); i < len; ++i) {
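
On the schedule-rule side, AddReadReuse now records the sampled vector length as an annotation on the cache-read block, with uniform probability over the configured candidates; the annotation is consumed later, presumably by a postprocessor such as RewriteCooperativeFetch, which performs the actual split and vectorization. A Python equivalent of the emitted trace fragment, assuming sch is a tir.Schedule, block is the cache-read block, and "meta_schedule.cooperative_fetch" is the string form of tir::attr::meta_schedule_cooperative_fetch:

    lens = [1, 2, 3, 4]
    vec_len = sch.sample_categorical(
        candidates=lens,
        probs=[1.0 / len(lens)] * len(lens),  # uniform, as in the C++ above
    )
    sch.annotate(block, "meta_schedule.cooperative_fetch", vec_len)
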
(Diffs for the remaining 3 changed files are not shown.)
