diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py index a263561006c8..0cbc07e680d6 100644 --- a/python/tvm/relay/op/dyn/nn/_nn.py +++ b/python/tvm/relay/op/dyn/nn/_nn.py @@ -38,7 +38,21 @@ def compute_upsampling(attrs, inputs, out_dtype): return [topi.nn.upsampling(data, scale_h, scale_w, layout, method, align_corners, out_dtype.shape)] +# upsampling3d +@register_compute("dyn.nn.upsampling3d") +def compute_upsampling3d(attrs, inputs, out_dtype): + data = inputs[0] + scale_d = inputs[1] + scale_h = inputs[2] + scale_w = inputs[3] + layout = attrs.layout + method = attrs.method + coordinate_transformation_mode = attrs.coordinate_transformation_mode + return [topi.nn.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,\ + coordinate_transformation_mode, out_dtype.shape)] + register_injective_schedule("dyn.nn.upsampling") +register_injective_schedule("dyn.nn.upsampling3d") register_broadcast_schedule("dyn.nn.pad") ##################### @@ -47,12 +61,12 @@ def compute_upsampling(attrs, inputs, out_dtype): # upsampling @script -def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis): +def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis): out = output_tensor((4,), "int64") - out[0] = int64(dshape[0]) + for i in const_range(4): + out[i] = int64(dshape[i]) out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) - out[channel_axis] = int64(dshape[channel_axis]) return out @register_shape_func("dyn.nn.upsampling", True) @@ -65,11 +79,39 @@ def upsampling_shape_func(attrs, inputs, _): height_axis = i if letter == "W": width_axis = i - if letter == "C": - channel_axis = i return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2], - convert(height_axis), convert(width_axis), - convert(channel_axis))] + convert(height_axis), convert(width_axis))] + +# upsampling3d +@script +def _upsampling3d_shape_func(dshape, scale_d, scale_h, scale_w, + depth_axis, height_axis, width_axis): + out = output_tensor((5,), "int64") + for i in const_range(5): + out[i] = int64(dshape[i]) + out[depth_axis] = int64(round(dshape[depth_axis] * scale_d[0])) + out[height_axis] = int64(round(dshape[height_axis] * scale_h[0])) + out[width_axis] = int64(round(dshape[width_axis] * scale_w[0])) + return out + + +@register_shape_func("dyn.nn.upsampling3d", True) +def upsampling3d_shape_func(attrs, inputs, _): + """Shape function for upsampling. Supports NCHW and NHWC layouts.""" + layout = attrs.layout + depth_axis = height_axis = width_axis = 1 + for i, letter in enumerate(layout): + if letter == "D": + depth_axis = i + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + return [_upsampling3d_shape_func(inputs[0].shape, inputs[1], inputs[2], + inputs[3], convert(depth_axis), + convert(height_axis), + convert(width_axis))] + # pad @script def _dyn_pad_shape_func(data, pad_width): diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 6f6849a79a9a..587dbc72e929 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1229,6 +1229,15 @@ def upsampling3d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(scale_d, Expr) or isinstance(scale_h, Expr) or isinstance(scale_w, Expr): + if not isinstance(scale_d, Expr): + scale_d = const(scale_d, "float64") + if not isinstance(scale_h, Expr): + scale_h = const(scale_h, "float64") + if not isinstance(scale_w, Expr): + scale_w = const(scale_w, "float64") + return _dyn_make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method, + coordinate_transformation_mode) return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method, coordinate_transformation_mode) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 25258925aa37..b159723c3ffe 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -519,8 +519,9 @@ def resize(data, size, layout="NCHW", method="bilinear", out_dtype: string, optional Type to return. If left None will be same as input type. - output_shape: optional + output_shape: tvm.tir.container.Array, optional Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) Returns ------- @@ -680,17 +681,22 @@ def resize3d(data, size, layout="NCDHW", method="nearest_neighbor", inputs is a 5-D tensor with shape [batch, channel, in_depth, in_height, in_width] or [batch, in_depth, in_height, in_width, channel] + size: Tuple Output resolution scale to + layout: string, optional "NCDHW", "NDHWC", or "NCDHWc". + coordinate_transformation_mode: string, optional Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". method: {"trilinear", "nearest_neighbor"} Method to be used for resizing. + out_dtype: string, optional Type to return. If left None will be same as input type. diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index db4af06e0694..6c07cf4e513c 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -44,6 +44,10 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', method : {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for upsampling. + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- output : tvm.te.Tensor @@ -79,7 +83,7 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', - coordinate_transformation_mode="half_pixel"): + coordinate_transformation_mode="half_pixel", output_shape=None): """Perform upsampling on the data. Nearest neighbor and bilinear upsampling are supported. @@ -111,6 +115,10 @@ def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='neares Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- output : tvm.te.Tensor @@ -119,15 +127,30 @@ def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='neares """ base_layout = layout[0:5] if base_layout == "NCDHW": - out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_d), data.shape[2].dtype)), - simplify(topi.cast(te.round(data.shape[3] * scale_h), data.shape[3].dtype)), - simplify(topi.cast(te.round(data.shape[4] * scale_w), data.shape[4].dtype))) + if not output_shape: # static case + scaled_d = data.shape[2] * scale_d + scaled_h = data.shape[3] * scale_h + scaled_w = data.shape[4] * scale_w + resize_shape = (simplify(topi.cast(te.round(scaled_d), data.shape[2].dtype)), + simplify(topi.cast(te.round(scaled_h), data.shape[3].dtype)), + simplify(topi.cast(te.round(scaled_w), data.shape[4].dtype))) + else: # dynamic case -- don't need to scale; already done in shape func + resize_shape = (simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)), + simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)), + simplify(topi.cast(te.round(output_shape[4]), data.shape[4].dtype))) elif layout == "NDHWC": - out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_d), data.shape[1].dtype)), - simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)), - simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype))) - + if not output_shape: # static case + scaled_d = data.shape[1] * scale_d + scaled_h = data.shape[2] * scale_h + scaled_w = data.shape[3] * scale_w + resize_shape = (simplify(topi.cast(te.round(scaled_d), data.shape[1].dtype)), + simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)), + simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype))) + else: # dynamic case + resize_shape = (simplify(topi.cast(te.round(output_shape[1]), data.shape[1].dtype)), + simplify(topi.cast(te.round(output_shape[2]), data.shape[2].dtype)), + simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype))) else: raise ValueError("not support this layout {} yet".format(layout)) - return topi.image.resize3d(data, out_shape, layout=layout, method=method, + return topi.image.resize3d(data, resize_shape, layout=layout, method=method, coordinate_transformation_mode=coordinate_transformation_mode) diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc index e2718481ac8c..9ed3298142af 100644 --- a/src/relay/op/dyn/nn/upsampling.cc +++ b/src/relay/op/dyn/nn/upsampling.cc @@ -95,16 +95,16 @@ RELAY_REGISTER_OP("dyn.nn.upsampling") (batch_size, channels, in_height, in_width) for NCHW (batch_size, in_height, in_width, channels) for NHWC -- **scale_h**: scale_h is an integer of the amount to scale height by +- **scale_h**: scale_h is a double of the amount to scale height by -- **scale_w**: scale_w is an integer of the amount to scale width by +- **scale_w**: scale_w is a double of the amount to scale width by - **out**: Output is 4D array of shape for layout NCHW - (batch_size, channels, in_height*scale, in_width*scale) + (batch_size, channels, in_height*scale_h, in_width*scale_w) for layout NHWC - (batch_size, in_height*scale, in_width*scale, channels) + (batch_size, in_height*scale_h, in_width*scale_w, channels) )code" TVM_ADD_FILELINE) .set_attrs_type() @@ -118,6 +118,84 @@ RELAY_REGISTER_OP("dyn.nn.upsampling") UpsamplingInferCorrectLayout) .set_attr("TOpPattern", kInjective); +// UpSampling3D +bool UpSampling3DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types = [data_type, scale_d_type, scale_h_type, scale_w_type, ret_type] + CHECK_EQ(types.size(), 5); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCDHW("NCDHW"); + + const UpSampling3DAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->layout); + + auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); + CHECK(layout_converter.defined()) + << "UpSampling3D only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + auto ncdhw_oshape = layout_converter.ForwardShape(data->shape); + + ncdhw_oshape.Set(2, Any()); + ncdhw_oshape.Set(3, Any()); + ncdhw_oshape.Set(4, Any()); + + auto oshape = layout_converter.BackwardShape(ncdhw_oshape); + + reporter->Assign(types[4], TensorType(oshape, data->dtype)); + return true; +} + +Expr MakeUpSampling3D(Expr data, Expr scale_d, Expr scale_h, Expr scale_w, String layout, + String method, String coordinate_transformation_mode) { + auto attrs = make_object(); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + + static const Op& op = Op::Get("dyn.nn.upsampling3d"); + return Call(op, {data, scale_d, scale_h, scale_w}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D); + +RELAY_REGISTER_OP("dyn.nn.upsampling3d") + .describe(R"code(Perform upsampling on input array with nearest neighbour or +bilinear interpolation. + +- **data**: data is 5D array of shape + (batch_size, channels, in_depth, in_height, in_width) for NCDHW + (batch_size, in_depth, in_height, in_width, channels) for NDHWC + +- **scale_d**: scale_d is a double of the amount to scale depth by + +- **scale_h**: scale_h is a double of the amount to scale height by + +- **scale_w**: scale_w is a double of the amount to scale width by + +- **out**: Output is 5D array of shape + for layout NCDHW + (batch_size, channels, in_depth*scale_d, in_height*scale_h, in_width*scale_w) + + for layout NDHWC + (batch_size, in_depth*scale_d, in_height*scale_h, in_width*scale_w, channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("scale_d", "double", "The scale for the depth.") + .add_argument("scale_h", "double", "The scale for the height.") + .add_argument("scale_w", "double", "The scale for the width.") + .set_support_level(2) + .add_type_rel("DynamicUpSampling3D", UpSampling3DRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); + } // namespace dyn } // namespace relay } // namespace tvm diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index fb3bf023140e..dc9ddee0f0bb 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -77,6 +77,9 @@ Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataT Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method, bool align_corners); +Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, String layout, + String method, String coordinate_transformation_mode); + Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude, bool unbiased); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 629f5afe9612..0c417ad857a2 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -139,6 +139,25 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, + {Op::Get("dyn.nn.upsampling3d"), + [](const CallNode* call_node) { + const ConstantNode* scale_d = call_node->args[1].as(); + const ConstantNode* scale_h = call_node->args[2].as(); + const ConstantNode* scale_w = call_node->args[3].as(); + if (scale_d && scale_h && scale_w) { + CHECK_EQ(scale_d->data->ndim, 0); + CHECK_EQ(scale_h->data->ndim, 0); + CHECK_EQ(scale_w->data->ndim, 0); + const UpSampling3DAttrs* param = call_node->attrs.as(); + CHECK(param); + + return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data), + ToScalar(scale_h->data), ToScalar(scale_w->data), + param->layout, param->method, + param->coordinate_transformation_mode); + } + return Expr(nullptr); + }}, {Op::Get("dyn.nn.pad"), [](const CallNode* call_node) { const ConstantNode* pad_width = call_node->args[1].as(); diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index 15b6b7acd7e9..b863d09db0a5 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -51,16 +51,16 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa zz = run_infer_type(z) func = relay.Function([x, scale_h_var, scale_w_var], z) - for target, ctx in tvm.testing.enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) - - verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NCHW", "nearest_neighbor") - verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "bilinear", True) - verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor") + for target, ctx in enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) + + verify_upsampling((1, 16, 32, 32), 3, 2.0,"NCHW", "nearest_neighbor") + verify_upsampling((1, 16, 32, 32), 5, 2.0, "NCHW", "bilinear", True) + verify_upsampling((1, 16, 32, 32), 2.0, 6, "NHWC", "nearest_neighbor") verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True) #tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable @@ -68,13 +68,64 @@ def test_dyn_upsampling_infer_type_const(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") data = relay.var("data", relay.TensorType((n, c, h, w), "int8")) - scale_h = relay.Var("scale_h", relay.TensorType((), "float32")) scale_w = relay.Var("scale_w", relay.TensorType((), "float32")) z = relay.nn.upsampling(data, 2.0, scale_w) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") +def test_dyn_upsampling3d_run(): + def verify_upsampling3d(dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel"): + + if layout == "NCDHW": + (n, c, d, h, w) = dshape + x_data = np.random.uniform(size=(n, c, d, h, w)).astype("float32") + + elif layout == "NDHWC": + (n, d, h, w, c) = dshape + x_data = np.random.uniform(size=(n, d, h, w, c)).astype("float32") + + if method == "nearest_neighbor": + ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale_d, scale_h, scale_w), layout) + else: + ref_res = tvm.topi.testing.trilinear_resize3d_python(x_data, (int(round(d*scale_d)), + int(round(h*scale_h)), + int(round(w*scale_w))), layout) + x = relay.Var("x", relay.TensorType(dshape, "float32")) + scale_d_var = relay.var("scale_d", relay.TensorType((), "float32")) + scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) + scale_w_var = relay.var("scale_h", relay.TensorType((), "float32")) + + z = relay.nn.upsampling3d(x, scale_d_var, scale_h_var, scale_w_var, method=method, layout=layout, + coordinate_transformation_mode=coord_trans) + zz = run_infer_type(z) + func = relay.Function([x, scale_d_var, scale_h_var, scale_w_var], z) + + for target, ctx in enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, np.array(scale_d).astype("float32"), np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) + + verify_upsampling3d((1, 1, 1, 1, 1), 2, 3, 4, "NCDHW", "nearest_neighbor") + verify_upsampling3d((1, 8, 16, 16, 16), 2.0, 3.0, 4.0, "NCDHW", "nearest_neighbor") + verify_upsampling3d((1, 8, 16, 16, 16), 2.0, 5.0, 1.0, "NCDHW", "trilinear", "align_corners") + verify_upsampling3d((1, 20, 3, 4, 16), 2.0, 2.0, 2.0, "NDHWC", "nearest_neighbor") + verify_upsampling3d((1, 8, 4, 16, 15), 2.0, 2.0, 2.0,"NDHWC", "trilinear", "align_corners") + +#tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable +def test_dyn_upsampling3d_infer_type_const(): + n, c, d, h, w = te.size_var("n"), te.size_var("c"), te.size_var("d"), te.size_var("h"), te.size_var("w") + + data = relay.var("data", relay.TensorType((n, c, d, h, w), "int8")) + scale_d = relay.Var("scale_h", relay.TensorType((), "float32")) + scale_w = relay.Var("scale_w", relay.TensorType((), "float32")) + + z = relay.nn.upsampling3d(data, scale_d, 2.0, scale_w, layout="NCDHW", method="trilinear") + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any(), relay.Any()), "int8") + def test_dyn_pad(): def verify_pad(dshape, pad_width, pad_val, dtype): x = relay.var("x", relay.TensorType(dshape, dtype)) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 453c469d2c07..d1bf846d8aec 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -354,6 +354,30 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): verify_upsampling((1, 16, 32, 32), 2, 2, 'int8') verify_upsampling((1, 16, 32, 32), 4, 4, 'int32') +def test_dynamic_to_static_upsampling3d(): + def verify_upsampling3d(data_shape, scale_d_val, scale_h_val, scale_w_val, dtype): + x = relay.var("x", relay.TensorType(data_shape, dtype)) + scale_d = relay.const(scale_d_val) + scale_h = relay.const(scale_h_val) + scale_w = relay.const(scale_w_val) + + z = relay.nn.upsampling3d(x, scale_d, scale_h, scale_w) + + func = run_infer_type(relay.Function([x], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("nn.upsampling3d") + + x_data = np.random.uniform(size=data_shape).astype(dtype) + ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale_d_val, scale_h_val, scale_w_val), "NCDHW") + verify_func(func2, [x_data], ref_res) + + verify_upsampling3d((1, 1, 1, 1, 1), 2, 3, 4, 'int8') + verify_upsampling3d((5, 7, 8, 10, 32), 3, 2, 2, 'int8') + verify_upsampling3d((1, 4, 2, 5, 3), 5, 4, 3, 'int32') + def test_dynamic_to_static_pad(): def verify_pad(data_shape, pad_width, pad_val, dtype): x = relay.var("x", relay.TensorType(data_shape, dtype))