diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 34935aec61b2..86f49872c67e 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -796,6 +796,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array axis, Array TVM_DLL PrimExpr prod(PrimExpr source, Array axis, Array init = {}, Span span = Span()); +/*! + * \brief Calculate fmod(x, y) + * \param x Left operand. + * \param y Right operand. + * \param span The location of this operation in the source. + * \return The result expression. + */ +TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span()); + /*! * \brief Calculate floor(x) * \param x The input expression. @@ -896,6 +905,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt); TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log2); TVM_DECLARE_INTRIN_UNARY(log10); +TVM_DECLARE_INTRIN_UNARY(log1p); TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(tan); TVM_DECLARE_INTRIN_UNARY(cos); diff --git a/python/tvm/script/builder/_ffi_api.py b/python/tvm/script/builder/_ffi_api.py index 3410494ded4d..98d8618ad7f1 100644 --- a/python/tvm/script/builder/_ffi_api.py +++ b/python/tvm/script/builder/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.builder""" import tvm._ffi -tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access +tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index e6e431e5bced..4aa0309b8741 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -31,3 +31,4 @@ ) from .prim_func_frame import arg, prim_func from .var import Buffer +from .op import * diff --git a/python/tvm/script/builder/tir/_ffi_api.py b/python/tvm/script/builder/tir/_ffi_api.py index 4e40e7261fd3..c0f5204f22ed 100644 --- a/python/tvm/script/builder/tir/_ffi_api.py +++ b/python/tvm/script/builder/tir/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.builder.tir""" import tvm._ffi -tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access +tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/tir/op.py b/python/tvm/script/builder/tir/op.py new file mode 100644 index 000000000000..d70f1f0a2920 --- /dev/null +++ b/python/tvm/script/builder/tir/op.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Op""" + +from . import _ffi_api + + +from tvm.tir.op import abs, popcount, nextafter, copysign, fmod +from tvm.tir.op import ( + floor, + floordiv, + floormod, + ceil, + round, + trunc, + truncdiv, + truncmod, + nearbyint, +) +from tvm.tir.op import ( + hypot, + ldexp, + power, + exp, + exp2, + exp10, + erf, + sqrt, + rsqrt, + log, + log2, + log10, + log1p, + sigmoid, +) +from tvm.tir.op import isnan, isfinite, isinf +from tvm.tir.op import cos, cosh, sin, sinh, tan, tanh +from tvm.tir.op import acos, acosh, asin, asinh, atan, atanh +from tvm.tir.op import atan2, clz, comm_reducer, infinity, reinterpret +from tvm.tir.op import min_value, max_value, if_then_else +from tvm.tir.op import call_packed, call_extern +from tvm.tir.expr import Select, Ramp, Broadcast, Shuffle +from tvm.tir.generic import cast + + +def boolean(expr): + return _ffi_api.PrimType("bool", expr) + + +def int8(expr): + return _ffi_api.PrimType("int8", expr) + + +def int16(expr): + return _ffi_api.PrimType("int16", expr) + + +def int32(expr): + return _ffi_api.PrimType("int32", expr) + + +def int64(expr): + return _ffi_api.PrimType("int64", expr) + + +def uint8(expr): + return _ffi_api.PrimType("uint8", expr) + + +def uint16(expr): + return _ffi_api.PrimType("uint16", expr) + + +def uint32(expr): + return _ffi_api.PrimType("uint32", expr) + + +def uint64(expr): + return _ffi_api.PrimType("uint64", expr) + + +def float8(expr): + return _ffi_api.PrimType("float8", expr) + + +def float16(expr): + return _ffi_api.PrimType("float16", expr) + + +def float32(expr): + return _ffi_api.PrimType("float32", expr) + + +def float64(expr): + return _ffi_api.PrimType("float64", expr) + + +def min(a, b, span=None): + """Compute the minimum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _ffi_api.min(a, b, span) # type: ignore + + +def max(a, b, span=None): + """Compute the maximum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _ffi_api.max(a, b, span) # type: ignore diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py index 18a8ecd59bbe..4c4163cb941a 100644 --- a/python/tvm/script/builder/tir/var.py +++ b/python/tvm/script/builder/tir/var.py @@ -20,10 +20,12 @@ from . import _ffi_api -def Buffer( # pylint: disable=invalid-name +def Buffer( # pylint: disable=invalid-name shape, dtype, name="buffer", storage_scope="", ) -> tir.Buffer: - return _ffi_api.Buffer(shape, dtype, name, storage_scope) # pylint: disable=no-member # type: ignore + return _ffi_api.Buffer( + shape, dtype, name, storage_scope + ) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 6db93b6ad091..173f4b8c4dbe 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -44,17 +44,36 @@ from .function import PrimFunc, TensorIntrin, IndexMap -from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace +from .op import call_packed, call_intrin, call_pure_extern, call_extern +from .op import ( + call_llvm_intrin, + call_llvm_pure_intrin, + ret, + all, + any, + min_value, + max_value, + trace, +) from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else +from .op import ( + trunc, + abs, + round, + nextafter, + nearbyint, + power, + popcount, + fmod, + if_then_else, +) from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod -from .op import comm_reducer, min, max, sum +from .op import comm_reducer, min, max, sum, infinity, reinterpret from .op import q_multiply_shift from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 5d15bf15da58..74c1437c6c5c 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -151,7 +151,10 @@ def call_pure_extern(dtype, func_name, *args, span=None): The call expression. """ return Call( - dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span + dtype, + Op.get("tir.call_pure_extern"), + convert((StringImm(func_name),) + args), + span, ) @@ -178,7 +181,10 @@ def call_extern(dtype, func_name, *args, span=None): The call expression. """ return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span + dtype, + Op.get("tir.call_extern"), + convert((StringImm(func_name),) + args), + span=span, ) @@ -210,7 +216,11 @@ def call_llvm_intrin(dtype, name, *args, span=None): llvm_id = codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name return call_intrin( - dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span + dtype, + Op.get("tir.call_llvm_intrin"), + tvm.tir.const(llvm_id, "uint32"), + *args, + span=span, ) @@ -394,6 +404,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any: return _ffi_api.max_value(dtype, span) # type: ignore +def infinity(dtype: str, span: Optional[Span] = None) -> Any: + """infinity value of dtype + + Parameters + ---------- + dtype : str + The data type. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The infinity value of dtype. + """ + return _ffi_api.infinity(dtype, span) # type: ignore + + +def reinterpret(dtype, value, span=None) -> Any: + """infinity value of dtype + + Parameters + ---------- + dtype : str + The data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The reinterpret cast value of dtype. + """ + return _ffi_api.reinterpret(dtype, value, span) # type: ignore + + def exp(x): """Take exponential of input x. diff --git a/src/script/builder/tir/op.cc b/src/script/builder/tir/op.cc new file mode 100644 index 000000000000..777ac8a4a407 --- /dev/null +++ b/src/script/builder/tir/op.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./op.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +PrimExpr prim_type(String type_name, PrimExpr expr) { + return cast(DataType(runtime::String2DLDataType(type_name)), expr); +} + +TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(prim_type); +TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return tvm::min(a, b, span); +}); +TVM_REGISTER_GLOBAL("script.builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return tvm::max(a, b, span); +}); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/op.h b/src/script/builder/tir/op.h new file mode 100644 index 000000000000..9f7a668f8330 --- /dev/null +++ b/src/script/builder/tir/op.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_OP_H_ +#define TVM_SCRIPT_BUILDER_TIR_OP_H_ + +#include +#include + +#include "../builder.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +PrimExpr int8(PrimExpr expr) { return cast(DataType::Int(8), expr); } +PrimExpr int16(PrimExpr expr) { return cast(DataType::Int(16), expr); } +PrimExpr int32(PrimExpr expr) { return cast(DataType::Int(32), expr); } +PrimExpr int64(PrimExpr expr) { return cast(DataType::Int(64), expr); } + +PrimExpr uint8(PrimExpr expr) { return cast(DataType::UInt(8), expr); } +PrimExpr uint16(PrimExpr expr) { return cast(DataType::UInt(16), expr); } +PrimExpr uint32(PrimExpr expr) { return cast(DataType::UInt(32), expr); } +PrimExpr uint64(PrimExpr expr) { return cast(DataType::UInt(64), expr); } + +PrimExpr float8(PrimExpr expr) { return cast(DataType::Float(8), expr); } +PrimExpr float16(PrimExpr expr) { return cast(DataType::Float(16), expr); } +PrimExpr float32(PrimExpr expr) { return cast(DataType::Float(32), expr); } +PrimExpr float64(PrimExpr expr) { return cast(DataType::Float(64), expr); } + +PrimExpr bool_(PrimExpr expr) { return cast(DataType::Bool(), expr); } + +PrimExpr prim_type(String type_name, PrimExpr expr); + +using tvm::cast; +using tvm::if_then_else; +using tvm::infinity; +using tvm::max; +using tvm::max_value; +using tvm::min; +using tvm::min_value; +using tvm::reinterpret; + +using tvm::ceil; +using tvm::floor; +using tvm::floordiv; +using tvm::floormod; +using tvm::nearbyint; +using tvm::round; +using tvm::trunc; +using tvm::truncdiv; +using tvm::truncmod; + +using tvm::abs; +using tvm::copysign; +using tvm::fmod; +using tvm::nextafter; +using tvm::popcount; + +using tvm::erf; +using tvm::exp; +using tvm::exp10; +using tvm::exp2; +using tvm::hypot; +using tvm::ldexp; +using tvm::log; +using tvm::log10; +using tvm::log1p; +using tvm::log2; +using tvm::pow; +using tvm::rsqrt; +using tvm::sigmoid; +using tvm::sqrt; + +using tvm::acos; +using tvm::acosh; +using tvm::asin; +using tvm::asinh; +using tvm::atan; +using tvm::atan2; +using tvm::atanh; +using tvm::clz; +using tvm::cos; +using tvm::cosh; +using tvm::sin; +using tvm::sinh; +using tvm::tan; +using tvm::tanh; + +using tvm::isfinite; +using tvm::isinf; +using tvm::isnan; + +using tvm::tir::Broadcast; +using tvm::tir::CommReducer; +using tvm::tir::Ramp; +using tvm::tir::Select; +using tvm::tir::Shuffle; + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_OP_H_ diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 73249921bf3b..3a5c39fda467 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -947,6 +947,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity); + +TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); + // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \