Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Remove LoadNode and StoreNode #14381

Merged
merged 1 commit into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 0 additions & 60 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -730,66 +730,6 @@ class ProducerLoad : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
};

/*!
* \brief Load the value from buffer_var.
*
* Equivalent to ((DType*)buffer_var)[index]
* where DType is the type specified by type().element_of().
*
* For example, if type = float32x3, then the load will corresponds to
*
* \code
*
* auto buffer = static_cast<float*>(buffer_var);
* auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
*
* \endcode
*/
class LoadNode : public PrimExprNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The index locations to be loaded. */
PrimExpr index;
/*! \brief The predicate to mask which lanes would be loaded. */
PrimExpr predicate;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("buffer_var", &buffer_var);
v->Visit("index", &index);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) &&
equal(index, other->index) && equal(predicate, other->predicate);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer_var);
hash_reduce(index);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.Load";
TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to LoadNode
* \sa LoadNode
*/
class Load : public PrimExpr {
public:
TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode);
};

/*!
* \brief Construct a vector with lanes elements
* where its i-th element equals base + i * stride.
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/tir/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
}
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -162,7 +161,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(VarNode);
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
Expand Down Expand Up @@ -214,7 +212,6 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
// list of functions to override.
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const ProducerLoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
Expand Down Expand Up @@ -261,7 +258,6 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
// list of functions to override.
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
Expand Down
66 changes: 0 additions & 66 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,72 +213,6 @@ class AssertStmt : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
};

/*!
* \brief Store value to the buffer.
*
* Equivalent to ((DType*)buffer_var)[index] = value.
* where DType is the type specified by type().element_of().
*
* For example, if type = float32x3, then the store will corresponds to
*
* \code
*
* auto buffer = static_cast<float*>(buffer_var);
* buffer[index.v0] = value.v0;
* buffer[index.v1] = value.v1;
* buffer[index.v2] = value.v2;
*
* \endcode
* \sa LoadNode
*/
class StoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The value to be stored. */
PrimExpr value;
/*! \brief The index locations to be stored. */
PrimExpr index;
/*! \brief The predicate to mask which lanes would be stored. */
PrimExpr predicate;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("value", &value);
v->Visit("index", &index);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
return equal(buffer_var, other->buffer_var) && equal(value, other->value) &&
equal(index, other->index) && equal(predicate, other->predicate);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_var);
hash_reduce(value);
hash_reduce(index);
hash_reduce(predicate);
}

static constexpr const char* _type_key = "tir.Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};

/*!
* \brief Managed reference to StoreNode.
* \sa StoreNode
*/
class Store : public Stmt {
public:
TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode);
};

/*!
* \brief Store value to the high dimension buffer.
*
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -117,7 +116,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
Expand Down Expand Up @@ -161,7 +159,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
void VisitStmt_(const DeclBufferNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
Expand Down Expand Up @@ -263,7 +260,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def _convert(item, nodes):
"Or": _rename("tir.Or"),
"Not": _rename("tir.Not"),
"Select": _rename("tir.Select"),
"Load": _rename("tir.Load"),
"BufferLoad": _rename("tir.BufferLoad"),
"Ramp": _rename("tir.Ramp"),
"Broadcast": _rename("tir.Broadcast"),
Expand All @@ -221,7 +220,6 @@ def _convert(item, nodes):
"Any": _rename("tir.Any"),
"LetStmt": _rename("tir.LetStmt"),
"AssertStmt": _rename("tir.AssertStmt"),
"Store": _rename("tir.Store"),
"BufferStore": _rename("tir.BufferStore"),
"BufferRealize": _rename("tir.BufferRealize"),
"Allocate": _rename("tir.Allocate"),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _remove_zero_store(stmt):

def _ftransform(f, mod, ctx):
return f.with_body(
tvm.tir.stmt_functor.ir_transform(f.body, _remove_zero_store, None, ["tir.Store"])
tvm.tir.stmt_functor.ir_transform(f.body, _remove_zero_store, None, ["tir.BufferStore"])
)

return tvm.tir.transform.prim_func_pass(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ def assign_addresses(buffer_info, npu_ops, scratch_region_map):
The key is the buffer name to BufferInfo
npu_ops : list
A list of Vela NpuOps with tir.BufferLoads for addresses
A list of Vela NpuOps with tir.Loads for addresses
scratch_region_map : Dict[tvm.tir.Var, RegionOffset]
A buffer_var to region and offset map.
Returns
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
FloorMod,
IntImm,
IterVar,
Load,
Max,
Min,
Mod,
Expand Down Expand Up @@ -2124,7 +2123,6 @@ def wrapped(*args, **kwargs):
"Select",
"BufferLoad",
"ProducerLoad",
"Load",
"Ramp",
"Broadcast",
"Shuffle",
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any

from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import (
BufferStore,
BufferRealize,
Store,
ProducerStore,
Allocate,
AllocateConst,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:

def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
"""Detect the lowest common ancestor(LCA) of buffer access, including both high-level
access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
access (BufferLoad, BufferStore) and low-level access (BufferLoad, BufferStore and opaque
access).
The LCA may be a For loop or a Block.

Parameters
Expand Down
30 changes: 0 additions & 30 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,36 +1013,6 @@ def __init__(self, condition, true_value, false_value, span=None):
)


@tvm._ffi.register_object("tir.Load")
class Load(PrimExprWithOp):
"""Load node.

Parameters
----------
dtype : str
The data type.

buffer_var : Var
The buffer variable in the load expression.

index : PrimExpr
The index in the load.

predicate : PrimExpr
The load predicate.

span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, dtype, buffer_var, index, predicate=None, span=None):
if predicate is None:
predicate = _ffi_api.const_true(dtype, span) # type: ignore
self.__init_handle_by_constructor__(
_ffi_api.Load, dtype, buffer_var, index, predicate, span # type: ignore
)


@tvm._ffi.register_object("tir.BufferLoad")
class BufferLoad(PrimExprWithOp):
"""Buffer load node.
Expand Down
38 changes: 4 additions & 34 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
.. code-block:: python

x = tvm.tir.Var("n", "int32")
a = tvm.tir.Var("array", "handle")
st = tvm.tir.stmt.Store(a, x + 1, 1)
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
buffer = tvm.tir.decl_buffer((16,), "float32")
st = tvm.tir.stmt.BufferStore(buffer, 1, (x,))
assert isinstance(st, tvm.tir.stmt.BufferStore)
assert(st.buffer == buffer)
"""
from enum import IntEnum
from typing import List, Mapping, Optional, Union
Expand Down Expand Up @@ -189,36 +189,6 @@ def __init__(self, condition, body, span=None):
)


@tvm._ffi.register_object("tir.Store")
class Store(Stmt):
"""Store node.

Parameters
----------
buffer_var : Var
The buffer Variable.

value : PrimExpr
The value we want to store.

index : PrimExpr
The index in the store expression.

predicate : PrimExpr
The store predicate.

span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, buffer_var, value, index, predicate=None, span=None):
if predicate is None:
predicate = _ffi_api.const_true(value.dtype, span) # type: ignore
self.__init_handle_by_constructor__(
_ffi_api.Store, buffer_var, value, index, predicate, span # type: ignore
)


@tvm._ffi.register_object("tir.BufferStore")
class BufferStore(Stmt):
"""Buffer store node.
Expand Down
6 changes: 0 additions & 6 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,6 @@ void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN
}
}

void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Load(s)!";
}

void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; }

void CodeGenHybrid::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no BufferLoad(s)!";
}
Expand Down
2 changes: 0 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
}
// expression
void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
Expand Down Expand Up @@ -121,7 +120,6 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const ProducerStoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
Expand Down
Loading