diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 088f076abb75..0f7e0e82ad4d 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -31,11 +31,17 @@ class TensorType(Type): Parameters ---------- - shape: List[tvm.Expr] + shape : List[tvm.Expr] The shape of the Tensor - dtype: str, optional + dtype : Optional[str] The content data type. + Default to "float32". + + Returns + ------- + tensor_type : tvm.relay.TensorType + The tensor type. """ def __init__(self, shape, dtype="float32"): self.__init_handle_by_constructor__( @@ -57,10 +63,10 @@ class Kind(IntEnum): @register_relay_node class TypeVar(Type): - """A type parameter used for generic types in Relay, + """A type variable used for generic types in Relay, see tvm/relay/type.h for more details. - A type parameter represents a type placeholder which will + A type variable represents a type placeholder which will be filled in later on. This allows the user to write functions which are generic over types. """ @@ -70,16 +76,17 @@ def __init__(self, var, kind=Kind.Type): Parameters ---------- - var: tvm.expr.Var + var : tvm.expr.Var The tvm.Var which backs the type parameter. - kind: Kind, optional + kind : Optional[Kind] The kind of the type parameter. + Default to Kind.Type. Returns ------- - type_param: TypeVar - The type parameter. + type_var : tvm.relay.TypeVar + The type variable. """ self.__init_handle_by_constructor__(_make.TypeVar, var, kind) @@ -102,11 +109,13 @@ def __init__(self, fields): Parameters ---------- - fields: List[tvm.relay.Type] + fields : List[tvm.relay.Type] + The fields in the tuple Returns ------- - tuple_type: the tuple type + tuple_type : tvm.relay.TupleType + the tuple type """ self.__init_handle_by_constructor__(_make.TupleType, fields) @@ -125,16 +134,16 @@ class FuncType(Type): Parameters ---------- - arg_types: List[tvm.relay.Type] + arg_types : List[tvm.relay.Type] The argument types - ret_type: tvm.relay.Type + ret_type : tvm.relay.Type The return type. - type_params: List[tvm.relay.TypeVar] + type_params : Optional[List[tvm.relay.TypeVar]] The type parameters - type_constraints: List[tvm.relay.TypeConstraint] + type_constraints : Optional[List[tvm.relay.TypeConstraint]] The type constraints. """ def __init__(self, @@ -163,18 +172,23 @@ class TypeRelation(TypeConstraint): Parameters ---------- - func: EnvFunc + func : EnvFunc User defined relation function. - args: list of types + args : [tvm.relay.Type] List of types to the func. - num_inputs: int + num_inputs : int Number of input arguments in args, this act as a hint for type inference. - attrs: Attrs + attrs : Attrs The attribute attached to the relation information + + Returns + ------- + type_relation : tvm.relay.TypeRelation + The type relation. """ def __init__(self, func, args, num_inputs, attrs): self.__init_handle_by_constructor__(_make.TypeRelation, @@ -188,12 +202,12 @@ def scalar_type(dtype): Parameters ---------- - dtype: str + dtype : str The content data type. Returns ------- - s_type: tvm.relay.TensorType + s_type : tvm.relay.TensorType The result type. """ return TensorType((), dtype)