Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR] New schedule primitive set_dtype #14316

Merged
merged 7 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,23 @@ class ScheduleNode : public runtime::Object {
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/*!
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
* \brief Set the storage scope of a buffer, where the buffer is specified by a block and a
* write-index
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
/*!
* \brief Set the data type of a buffer, where the buffer is specified by a block and a
* write-index
* \note This schedule primitive is unsafe and may change correctness of program because of
* type conversion, please use with caution.
* \param block_rv The producer block of the buffer
* \param buffer_index the index of the buffer in block's write region
* \param dtype The data type to be set
*/
virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0;
/******** Schedule: Blockize & Tensorize ********/
/*!
* \brief Convert the subtree rooted at a specific loop into a block.
Expand Down
79 changes: 77 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,7 +2322,7 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
@type_checked
def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None:
"""Set the storage scope of a buffer, where the buffer is
specified by the a block and a write-index
specified by the a block and a write-index.

Parameters
----------
Expand Down Expand Up @@ -2384,13 +2384,88 @@ def after_set_scope(

Note
----
Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
`set_scope` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member
self, block, buffer_index, storage_scope
)

@type_checked
def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None:
"""Set the data type of a buffer, where the buffer is
specified by the a block and write-index.

This schedule primitive is unsafe and may change the correctness of program because of
type conversion, please use with caution.

Parameters
----------
block : Union[BlockRV, str]
The producer block of the buffer
buffer_index : int
The index of the buffer in block's write region
dtype : str
The data type to be set

Examples
--------

Before set_dtype, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_set_dtype(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), dtype="float32")

for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j]
C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do set_dtype:

.. code-block:: python

sch = tir.Schedule(before_set_dtype)
sch.set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())

After applying set_dtype, the IR becomes:

.. code-block:: python

@T.prim_func
def after_set_dtype(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), dtype="float16")

for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j]
C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0

Note
----
`set_dtype` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleUnsafeSetDType( # type: ignore # pylint: disable=no-member
self, block, buffer_index, dtype
)

########## Schedule: Blockize & Tensorize ##########

@type_checked
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
this->state_->DebugVerify();
}

void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
const String& dtype) {
TVM_TIR_SCHEDULE_BEGIN();
tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Reduction ********/

BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,18 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
*/
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& storage_scope);
/*!
* \brief Set the data type of a buffer, where the buffer is specified by a block and a
* write-index
* \note This schedule primitive is unsafe and may change correctness of program because of
* type conversion, please use with caution.
* \param self The state of the schedule
* \param block_sref The sref of the producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param dtype The data type to be set
*/
TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& dtype);
/*!
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
* or write index
Expand Down
117 changes: 117 additions & 0 deletions src/tir/schedule/primitive/block_annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/tir/expr.h>

#include "../utils.h"

namespace tvm {
Expand Down Expand Up @@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
self->Replace(alloc_site_sref, new_block, block_reuse_map);
}

/*!
* \brief A helper mutator which recursively mutates the old buffer's data type, inserts data type
* conversions, and collecte the block sref reuse information for the following replacement.
*/
class DTypeMutator : private ReplaceBufferMutator {
public:
/*!
* \param allocate_site The block where `old_buffer` was allocated.
* \param old_buffer The old buffer
* \param target_dtype The data type to be set
* \param block_sref_reuse The block sref reuse map to be updated
* \return The new block after the mutation
*/
static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype,
Map<Block, Block>* block_sref_reuse) {
Buffer new_buffer = WithDType(old_buffer, dtype);
DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse);
Stmt new_block = mutator.VisitStmt(allocate_site);
return Downcast<Block>(new_block);
}

private:
DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype,
Map<Block, Block>* block_sref_reuse)
: ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse),
src_dtype_(old_buffer->dtype),
tgt_dtype_(dtype) {}

MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
if (it != buffer_var_map_.end()) {
Buffer new_target_buffer = WithDType(match_buffer->buffer, it->second->dtype);
buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
return MatchBufferRegion(new_target_buffer,
BufferRegion(it->second, match_buffer->source->region));
} else {
return match_buffer;
}
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_var_map_.find(node->buffer->data.get());
if (it != buffer_var_map_.end()) {
node.CopyOnWrite()->buffer = it->second;
node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value);
}
return node;
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_var_map_.find(node->buffer->data.get());
if (it != buffer_var_map_.end()) {
return Cast(src_dtype_, BufferLoad(it->second, node->indices));
}
return node;
}

DataType src_dtype_, tgt_dtype_;
};

void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& dtype) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
Buffer buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
DataType target_dtype(runtime::String2DLDataType(dtype));

// Step 1. If `dtype` equals the original data type, just return.
if (buffer->dtype == target_dtype) {
return;
}

// Step 2. Get the allocation site of the target buffer.
StmtSRef alloc_site_sref =
NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);

// Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given
// dtype, and insert data type conversions.
Map<Block, Block> block_reuse_map;
Block new_block =
DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype, &block_reuse_map);
self->Replace(alloc_site_sref, new_block, block_reuse_map);
}

/******** InstructionKind Registration ********/

struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
Expand Down Expand Up @@ -356,8 +445,36 @@ struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> {
static constexpr const char* kName = "UnsafeSetDType";
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
String dtype) {
return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
String dtype) {
PythonAPICall py("unsafe_set_dtype");
py.Input("block", block_rv);
py.Input("buffer_index", buffer_index);
py.Input("dtype", dtype);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits);

} // namespace tir
} // namespace tvm
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
.set_body_method<Schedule>(&ScheduleNode::StorageAlign);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
.set_body_method<Schedule>(&ScheduleNode::SetScope);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
.set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
/******** (FFI) Blockize & Tensorize ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
.set_body_method<Schedule>(&ScheduleNode::Blockize);
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
/*outputs=*/{}));
}

void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
const String& dtype) {
ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype);
static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType");
trace_->Append(/*inst=*/Instruction(
/*kind=*/kind,
/*inputs=*/{block_rv},
/*attrs=*/{Integer(buffer_index), dtype},
/*outputs=*/{}));
}

/******** Schedule: Blockize & Tensorize ********/

BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) final;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final;
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final;
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) {
return Buffer(new_buffer);
}

Buffer WithDType(const Buffer& buffer, const DataType& dtype) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
new_buffer->dtype = dtype;
const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
new_buffer->data =
Var(buffer->data->name_hint, PointerType(PrimType(dtype), ptr_type->storage_scope));
new_buffer->name = buffer->name;
return Buffer(new_buffer);
}

Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
const Buffer& target) {
regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {
Expand Down
12 changes: 10 additions & 2 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec
*/
Buffer WithScope(const Buffer& buffer, const String& scope);

/*!
* \brief Create a new buffer by changint the data type.
* \param buffer The given buffer.
* \param scope The target data type.
* \return The new buffer with target data type.
*/
Buffer WithDType(const Buffer& buffer, const DataType& dtype);

/*!
* \brief Replaces the buffer within the specific sequence of regions
* \param regions The regions whose buffers are to be replaced
Expand Down Expand Up @@ -131,9 +139,9 @@ class ReplaceBufferMutator : public StmtExprMutator {
return node;
}

Stmt VisitStmt_(const BufferStoreNode* op) final;
Stmt VisitStmt_(const BufferStoreNode* op) override;

PrimExpr VisitExpr_(const BufferLoadNode* op) final;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;

virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);

Expand Down
Loading