diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index ff611d1f44db..5b84942a57cf 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -124,7 +124,8 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   tvm::String data_layout;
   tvm::String kernel_layout;
   tvm::String out_layout;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
+  Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
@@ -217,7 +218,8 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
   tvm::String data_layout;
   tvm::String kernel_layout;
   tvm::String out_layout;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
+  Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
@@ -308,7 +310,8 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
   tvm::String data_layout;
   tvm::String kernel_layout;
   tvm::String out_layout;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
+  Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") {
@@ -1049,7 +1052,8 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
   DataType out_dtype;
   bool transpose_a;
   bool transpose_b;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
+  Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
 
   TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
     TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
@@ -1072,7 +1076,8 @@ struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
 /*! \brief Attributes for dense operator */
 struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   IndexExpr units;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
+  Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
@@ -1109,7 +1114,8 @@ struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
   DataType out_dtype;
   bool transpose_a;
   bool transpose_b;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+  tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
+  Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
 
   TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {
     // use 0 bits to indicate none.
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index a6f6390b2110..a5e2d9f51cd7 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -188,6 +188,18 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   Layout kOIHW("OIHW");
 
   const auto* param = attrs.as<Conv2DAttrs>();
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+    if (out_dtype.bits() == 0 && weight != nullptr) {
+      out_dtype = weight->dtype;
+    }
+  }
+  TensorType meta_schedule_weight{nullptr};
+  if (param->meta_schedule_original_shape.size() != 0) {
+    meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
+    weight = meta_schedule_weight.get();
+  }
   ICHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->kernel_layout);
@@ -273,27 +285,27 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       weight_dtype = weight->dtype;
     }
 
-    if (param->auto_scheduler_rewritten_layout.size() == 0) {
-      // Normal case: assign result to reporter
-      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
-    } else {
+    if (param->auto_scheduler_rewritten_layout.size() != 0) {
       // If the layout is rewritten by auto-scheduler,
       // we just forcly apply the layout provided by auto-scheduler and
       // skip the normal inference logic.
       {}  // do nothing
+    } else {
+      // Normal case: assign result to reporter
+      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
     }
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
 
     Array<IndexExpr> wshape;
-    if (param->auto_scheduler_rewritten_layout.size() == 0) {
-      wshape = weight->shape;
-    } else {
+    if (param->auto_scheduler_rewritten_layout.size() != 0) {
       // works for the default kernel layout "HWIO"
       ICHECK_EQ(param->kernel_layout, "HWIO");
       wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
                                                            {"ry", "rx", "rc", "ff"});
+    } else {
+      wshape = weight->shape;
     }
 
     wshape = trans_kernel_layout.ForwardShape(wshape);
@@ -357,10 +369,6 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   } else {
     oshape.Set(3, dshape_nchw[3]);
   }
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
   oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
   reporter->Assign(types[2], TensorType(oshape, out_dtype));
@@ -412,6 +420,18 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const auto* param = attrs.as<Conv3DAttrs>();
   ICHECK(param != nullptr);
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+    if (out_dtype.bits() == 0 && weight != nullptr) {
+      out_dtype = weight->dtype;
+    }
+  }
+  TensorType meta_schedule_weight{nullptr};
+  if (param->meta_schedule_original_shape.size() != 0) {
+    meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
+    weight = meta_schedule_weight.get();
+  }
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->kernel_layout);
 
@@ -450,14 +470,14 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       weight_dtype = weight->dtype;
     }
 
-    if (param->auto_scheduler_rewritten_layout.size() == 0) {
-      // Normal case: assign result to reporter
-      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
-    } else {
+    if (param->auto_scheduler_rewritten_layout.size() != 0) {
       // If the layout is rewritten by auto-scheduler,
       // we just forcly apply the layout provided by auto-scheduler and
       // skip the normal inference logic.
       {}  // do nothing
+    } else {
+      // Normal case: assign result to reporter
+      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
     }
 
   } else {
@@ -465,13 +485,13 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     if (weight == nullptr) return false;
 
     Array<IndexExpr> wshape;
-    if (param->auto_scheduler_rewritten_layout.size() == 0) {
-      wshape = weight->shape;
-    } else {
+    if (param->auto_scheduler_rewritten_layout.size() != 0) {
       // works for the default kernel layout "DHWIO"
       ICHECK_EQ(param->kernel_layout, "DHWIO");
       wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
                                                            {"rd", "rh", "rw", "rc", "cc"});
+    } else {
+      wshape = weight->shape;
     }
 
     wshape = trans_kernel_layout.ForwardShape(wshape);
@@ -521,10 +541,6 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   } else {
     oshape.Set(4, dshape_ncdhw[4]);
   }
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
   oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
   reporter->Assign(types[2], TensorType(oshape, out_dtype));
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 6bc21473af18..33d4c946e408 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -48,6 +48,12 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const AttrType* param = attrs.as<AttrType>();
   ICHECK(param != nullptr);
+  TensorType meta_schedule_tensor_b{nullptr};
+  if (param->meta_schedule_original_shape.size() > 0) {
+    meta_schedule_tensor_b = TensorType(param->meta_schedule_original_shape,
+                                        tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
+    tensor_b = meta_schedule_tensor_b.get();
+  }
   // Default set to dense layout
   bool transpose_a = false;
   bool transpose_b = true;
@@ -73,14 +79,14 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     // data dtype as the tensor_b dtype. However if tensor_b dtype is explicitly
     // present we will use that.
     auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
-    if (param->auto_scheduler_rewritten_layout.size() == 0) {
-      // Normal case: assign result to reporter
-      reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
-    } else {
-      // If the layout is rewritten by auto-scheduler,
-      // we just forcly apply the layout provided by auto-scheduler and
+    if (param->auto_scheduler_rewritten_layout.size() != 0) {
+      // If the layout is rewritten by auto-scheduler or meta-schedule,
+      // we just forcefully apply the layout provided by auto-scheduler and
       // skip the normal inference logic.
       {}  // do nothing
+    } else {
+      // Normal case: assign result to reporter
+      reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
     }
     oshape.Set((oshape.size() - 1), param->units);
   } else {
@@ -103,7 +109,7 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
           << "MatmulRel: input dimension doesn't match,"
           << " tensor_a shape=" << tensor_a->shape << ", tensor_b shape=" << tensor_b->shape;
     }
-    oshape.Set((oshape.size() - 1), transpose_b ? wshape[0] : wshape[1]);
+    oshape.Set(oshape.size() - 1, transpose_b ? wshape[0] : wshape[1]);
   }
 
 }
@@ -125,16 +131,32 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (x == nullptr || y == nullptr) return false;
 
   const AttrType* param = attrs.as<AttrType>();
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = x->dtype;
+    if (x->dtype.bits() == 0) {
+      out_dtype = y->dtype;
+    }
+  }
+  TensorType meta_schedule_y{nullptr};
+  if (param->meta_schedule_original_shape.size() != 0) {
+    meta_schedule_y = TensorType(param->meta_schedule_original_shape, out_dtype);
+    y = meta_schedule_y.get();
+  }
   ICHECK(param != nullptr);
   bool transpose_a = param->transpose_a;
   bool transpose_b = param->transpose_b;
-  const Array<PrimExpr>& y_shape =
-      param->auto_scheduler_rewritten_layout.size() == 0
-          ? y->shape
-          : auto_scheduler::GetShapeFromRewrittenLayout(
-                param->auto_scheduler_rewritten_layout,
-                transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
-                            : tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
+  Array<PrimExpr> y_shape{nullptr};
+  if (param->auto_scheduler_rewritten_layout.size() != 0) {
+    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(
+        param->auto_scheduler_rewritten_layout,
+        transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
+                    : tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
+  } else if (param->meta_schedule_original_shape.size() != 0) {
+    y_shape = param->meta_schedule_original_shape;
+  } else {
+    y_shape = y->shape;
+  }
   ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
   const PrimExpr& xb = x->shape[0];
   const PrimExpr& xi = x->shape[transpose_a ? 2 : 1];
@@ -158,10 +180,6 @@
         << " x shape=" << x->shape << ", y shape=" << y_shape;
   }
 
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = x->dtype;
-  }
   // assign output type
   const auto& out_b =
       xb->IsInstance<AnyNode>() || yb->IsInstance<AnyNode>() ? tir::Any() : max(xb, yb);
diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index 00162abc69f9..37385f80c1c9 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -126,6 +126,7 @@ class SimplifyExplicitPad {
 
     T* new_attrs = const_cast<T*>(attrs.template as<T>());
     new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout;
+    new_attrs->meta_schedule_original_shape = old_attrs->meta_schedule_original_shape;
     return attrs;
   }
 