diff --git a/cinn/backends/codegen_c_x86.cc b/cinn/backends/codegen_c_x86.cc
index 3730dad793490..eecbf0c2fce04 100644
--- a/cinn/backends/codegen_c_x86.cc
+++ b/cinn/backends/codegen_c_x86.cc
@@ -14,13 +14,11 @@ void CodeGenCX86::Visit(const ir::Load *op) {
   CHECK(op->type().is_vector());
 
   int bits = op->type().bits() * op->type().lanes();
-  if (SupportsAVX512()) {
-    CHECK_EQ(bits, 512);
+  if (SupportsAVX512() && bits == 512) {
     os() << "cinn_avx512_load(";
     PrintAbsAddr(op);
     os() << ")";
-  } else if (SupportsAVX256()) {
-    CHECK_EQ(bits, 256);
+  } else if (SupportsAVX256() && bits == 256) {
     os() << "cinn_avx256_load(";
     PrintAbsAddr(op);
     os() << ")";
@@ -36,13 +34,11 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
   CHECK_GT(op->type().lanes(), 1);
 
   int bits = op->type().bits() * op->type().lanes();
-  if (SupportsAVX512()) {
-    CHECK_EQ(bits, 512);
+  if (SupportsAVX512() && bits == 512) {
     os() << "cinn_avx512_set1(";
     PrintCastExpr(op->value.type().ElementOf(), op->value);
     os() << ")";
-  } else if (SupportsAVX256()) {
-    CHECK_EQ(bits, 256);
+  } else if (SupportsAVX256() && bits == 256) {
     os() << "cinn_avx256_set1(";
     PrintCastExpr(op->value.type().ElementOf(), op->value);
     os() << ")";
@@ -58,15 +54,13 @@ void CodeGenCX86::Visit(const ir::Store *op) {
   }
 
   int bits = op->type().bits() * op->type().lanes();
-  if (SupportsAVX512()) {
-    CHECK_EQ(bits, 512);
+  if (SupportsAVX512() && bits == 512) {
     os() << "cinn_avx512_store(";
     PrintAbsAddr(op);
     os() << ", ";
     Print(op->value);
     os() << ")";
-  } else if (SupportsAVX256()) {
-    CHECK_EQ(bits, 256);
+  } else if (SupportsAVX256() && bits == 256) {
     os() << "cinn_avx256_store(";
     PrintAbsAddr(op);
     os() << ", ";
diff --git a/cinn/backends/codegen_c_x86.h b/cinn/backends/codegen_c_x86.h
index 755bf70e58711..b6326152bf077 100644
--- a/cinn/backends/codegen_c_x86.h
+++ b/cinn/backends/codegen_c_x86.h
@@ -95,22 +95,20 @@ void CodeGenCX86::VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string
   // TODO(Superjomn) Consider support BLAS.
   int bits = a.type().bits() * a.type().lanes();
-  if (SupportsAVX512()) {
-    CHECK_EQ(bits, 512) << "the bits of computation should be times of 512";
+  if (SupportsAVX512() && bits == 512) {
     os() << "cinn_avx512_" << op_repr << "(";
     PrintVecInputArgument(&a);
     os() << ", ";
     PrintVecInputArgument(&b);
     os() << ")";
-  } else if (SupportsAVX256()) {
-    CHECK_EQ(bits, 256) << "the bits of computation should be times of 256";
+  } else if (SupportsAVX256() && bits == 256) {
     os() << "cinn_avx256_" << op_repr << "(";
     PrintVecInputArgument(&a);
     os() << ", ";
     PrintVecInputArgument(&b);
     os() << ")";
   } else {
-    CINN_NOT_IMPLEMENTED
+    CodeGenC::Visit(op);
   }
 }
diff --git a/cinn/backends/llvm/codegen_llvm.cc b/cinn/backends/llvm/codegen_llvm.cc
index 6a1233e8327d8..33317da59c14d 100644
--- a/cinn/backends/llvm/codegen_llvm.cc
+++ b/cinn/backends/llvm/codegen_llvm.cc
@@ -803,7 +803,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) {
       auto *vtype = llvm::VectorType::get(ll_type_of(op->type().ElementOf()),
                                           llvm::ElementCount(lanes, false /*Scalable*/))
                         ->getPointerTo();
-      int alignment = lanes * op->type().ElementOf().bits() / 8;
+      int alignment = std::max(op->type().ElementOf().bits() / 8, 1);
       llvm::StoreInst *inst =
           b_->CreateAlignedStore(CreateVecSlice(value, offset, lanes), b_->CreatePointerCast(ptr, vtype), alignment);
       AddTbaaMetadata(inst, op->tensor.as_tensor()->name, base);
@@ -1086,7 +1086,7 @@ llvm::Value *CodeGenLLVM::DenseVectorLoad(const ir::Load *op) {
   llvm::Value *elt_ptr = CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&slice_base));
   llvm::Value *vec_ptr = b_->CreatePointerCast(elt_ptr, slice_type->getPointerTo(), "get_vec_ptr");
-  int alignment = slice_lanes * op->type().ElementOf().bits() / 8;
+  int alignment = std::max(op->type().ElementOf().bits() / 8, 1);
   llvm::Instruction *load_inst = b_->CreateAlignedLoad(vec_ptr, llvm::Align(alignment), "load_vec");
   AddTbaaMetadata(load_inst, op->tensor.as_tensor()->name, op->index());
diff --git a/cinn/hlir/framework/tensor.h b/cinn/hlir/framework/tensor.h
index 5867a680fa84f..8eb4bdd9bb266 100644
--- a/cinn/hlir/framework/tensor.h
+++ b/cinn/hlir/framework/tensor.h
@@ -49,7 +49,7 @@ class _Tensor_ : public Object {
   inline T* mutable_data(const Target& target) {
     set_type(type_of<T>());
     if (target == common::DefaultHostTarget()) {
-      int alignment = target.get_target_bits() * 8;
+      int alignment = type_of<T>().ElementOf().bits();
       buffer_->ResizeLazy(alignment, shape_.numel() * sizeof(T), target);
     } else {
       buffer_->ResizeLazy(shape_.numel() * sizeof(T), target);
diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc
index 03f7f8285ca55..8d5fb7f0e1525 100755
--- a/cinn/hlir/op/nn.cc
+++ b/cinn/hlir/op/nn.cc
@@ -40,10 +40,15 @@ std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
     CINNValuePack arg_pack = args[0];
     CHECK_EQ(arg_pack.size(), 2UL);
     if (target.arch == Target::Arch::NVGPU) {
-      Expr Out = arg_pack[0];
+      Expr out = arg_pack[0];
       poly::StageMap stages = arg_pack[1];
-      CHECK(Out.as_tensor());
-      pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
+      CHECK(out.as_tensor());
+      pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.back(), target);
+    } else if (target.arch == Target::Arch::X86) {
+      Expr out = arg_pack[0];
+      poly::StageMap stages = arg_pack[1];
+      CHECK(out.as_tensor());
+      pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.back(), target);
     }
     *ret = arg_pack;
   });
@@ -92,10 +97,15 @@ std::shared_ptr<OpStrategy> StrategyForRelu6(const framework::NodeAttr &attrs,
     CINNValuePack arg_pack = args[0];
     CHECK_EQ(arg_pack.size(), 2UL);
     if (target.arch == Target::Arch::NVGPU) {
-      Expr Out = arg_pack[0];
+      Expr out = arg_pack[0];
       poly::StageMap stages = arg_pack[1];
-      CHECK(Out.as_tensor());
-      pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
+      CHECK(out.as_tensor());
+      pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.back(), target);
+    } else if (target.arch == Target::Arch::X86) {
+      Expr out = arg_pack[0];
+      poly::StageMap stages = arg_pack[1];
+      CHECK(out.as_tensor());
+      pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.back(), target);
     }
     *ret = arg_pack;
   });
diff --git a/cinn/hlir/pe/nn.h b/cinn/hlir/pe/nn.h
index 5e0ba4577c0ba..9f0daebe8f2b9 100644
--- a/cinn/hlir/pe/nn.h
+++ b/cinn/hlir/pe/nn.h
@@ -214,7 +214,9 @@ ir::Tensor Pad(const ir::Tensor &tensor,
                const std::string &name = UniqName("T_pad_out"),
                const std::string &pad_mode = "constant");
 
-std::vector<ir::Tensor> Softmax(const ir::Tensor &A, int axis=-1, const std::string &output_name= UniqName("T_softmax_out"));
+std::vector<ir::Tensor> Softmax(const ir::Tensor &A,
+                                int axis = -1,
+                                const std::string &output_name = UniqName("T_softmax_out"));
 
 ir::Tensor Slice(const ir::Tensor &A,
                  const std::vector<int> &starts,
diff --git a/cinn/hlir/pe/schedule.cc b/cinn/hlir/pe/schedule.cc
index cbbbaccb6107f..4f854aa4ca6dc 100644
--- a/cinn/hlir/pe/schedule.cc
+++ b/cinn/hlir/pe/schedule.cc
@@ -7,6 +7,15 @@ namespace cinn {
 namespace hlir {
 namespace pe {
 
+int GetBetterSplitFactor(int shape, int split_factor) {
+  int better_factor = split_factor;
+  while (better_factor > shape) {
+    better_factor /= 2;
+  }
+  if (better_factor < shape) return better_factor * 2;
+  return better_factor;
+}
+
 void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target) {
   int dims = stage->n_out_dims();
   if (dims > 1) {
@@ -24,19 +33,19 @@ void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector &output_sha
       if (last_two_dim_bits % target_native_vector_bits == 0) {
         fused = stage->Fuse(dims - 2, dims - 1);
         prod_size *= output_shape[dims - 2];
-      } else {
-        return;
       }
     }
     int split_factor = target_native_vector_bits / type_bits;
-    if (prod_size == split_factor) {
-      stage->Vectorize(fused, split_factor);
-      return;
+    if (prod_size <= split_factor) {
+      split_factor = GetBetterSplitFactor(prod_size, split_factor);
+      if (split_factor >= 8) {
+        stage->Vectorize(fused, split_factor);
+      }
+    } else {
+      auto [j_outer, j_inner] = stage->Split(fused, split_factor);
+      stage->Vectorize(j_inner, split_factor);
     }
-    auto [j_outer, j_inner] = stage->Split(fused, split_factor);
-    stage->Vectorize(j_inner, split_factor);
   }
-  return;
 }
diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc
index 91f4cfa344f79..50dda8c7297d4 100644
--- a/cinn/optim/vectorize_loops.cc
+++ b/cinn/optim/vectorize_loops.cc
@@ -314,7 +314,7 @@ class Vectorizer : public IRMutator {
   }
 
   template <typename T>
-  Expr BinaryOperatorVec(const T *op, Expr *expr) {
+  void BinaryOperatorVec(const T *op, Expr *expr) {
     auto *node = expr->As<T>();
     Expr a0 = node->a();
     Expr b0 = node->b();
@@ -323,7 +323,7 @@
     // if (a0.same_as(node->a()) && b0.same_as(node->b())) return *expr;
 
     int lanes = std::max(node->a().type().lanes(), node->b().type().lanes());
-    return T::Make(Widen(node->a(), lanes), Widen(node->b(), lanes));
+    *expr = T::Make(Widen(node->a(), lanes), Widen(node->b(), lanes));
   }
 };
@@ -521,8 +521,17 @@ struct VectorizeLoops_ : public IRMutator {
     if (!for_min_i) return Expr();
     if (for_min_i->value != 0) return Expr();
 
-    Expr times = common::AutoSimplify(Div::Make(forloop->extent, make_const(factor)));
-    Simplify(&times);
+    auto *extent_ptr = forloop->extent.As<IntImm>();
+    Expr times;
+    if (extent_ptr) {
+      int extent_int = forloop->extent.as_int32();
+      int extent_trunc = extent_int / factor;
+      int extent_times = extent_int % factor == 0 ? extent_trunc : extent_trunc + 1;
+      times = common::make_const(forloop->extent->type(), extent_times);
+    } else {
+      times = common::AutoSimplify(Div::Make(forloop->extent, make_const(factor)));
+      Simplify(&times);
+    }
 
     // update the current forloop
     auto times_int = times.As<IntImm>();
diff --git a/cinn/optim/vectorize_loops_test.cc b/cinn/optim/vectorize_loops_test.cc
index 8eec58a11cfec..6d9c0f23d9dc2 100644
--- a/cinn/optim/vectorize_loops_test.cc
+++ b/cinn/optim/vectorize_loops_test.cc
@@ -65,7 +65,7 @@ void matmul(void* _args, int32_t num_args)
   const float* B = ((const float*)(_B->memory));
   float* C = ((float*)(_C->memory));
   for (int32_t i = 0; i < 100; i += 1) {
-    for (int32_t j = 0; j < 31; j += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
       C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j)), 1, 16)] = (StackedVec::Load(A,((500 * i) + (16 * j))) * StackedVec::Load(B,((500 * i) + (16 * j))));
     };
   };
@@ -136,11 +136,13 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
   float* D = (float*)(_C->memory);
   for (int32_t i = 0; i < 100; i += 1) {
     for (int32_t j_outer = 0; j_outer < 31; j_outer += 1) {
-      C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j_outer)), 1, 16)] = (StackedVec::Load(A,((500 * i) + (16 * j_outer))) * StackedVec::Load(B,((500 * i) + (16 * j_outer))));
+      C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j_outer)), 1, 16)] = (StackedVec::Load(A,((500 * i) +
+      (16 * j_outer))) * StackedVec::Load(B,((500 * i) + (16 * j_outer))));
     };
     for (int32_t j_outer = 31; j_outer < 32; j_outer += 1) {
       for (int32_t j_inner = 0; j_inner < (500 + (-16 * j_outer)); j_inner += 1) {
-        C[((500 * i) + ((16 * j_outer) + j_inner))] = (A[((500 * i) + ((16 * j_outer) + j_inner))] * B[((500 * i) + ((16 * j_outer) + j_inner))]);
+        C[((500 * i) + ((16 * j_outer) + j_inner))] = (A[((500 * i) + ((16 * j_outer) + j_inner))] * B[((500 * i) +
+        ((16 * j_outer) + j_inner))]);
       };
     };
   };
diff --git a/python/tests/test_utils.py b/python/tests/test_utils.py
index ed5e9b547ae02..4acd8fdd06a23 100644
--- a/python/tests/test_utils.py
+++ b/python/tests/test_utils.py
@@ -69,7 +69,7 @@ def to_test_op(self, input_shapes, output_shapes, op_name, attrs):
         temp_inputs = []
         alignment = 0
         if self.target.arch == common.Target.Arch.X86:
-            alignment = 512
+            alignment = 32
         for in_data in inputs_data:
             temp_inputs.append(
                 runtime.cinn_buffer_t(in_data, runtime.cinn_x86_device,
diff --git a/tests/benchmark/test_utils.cc b/tests/benchmark/test_utils.cc
index ed24e999f15a0..9966b0cfa2281 100644
--- a/tests/benchmark/test_utils.cc
+++ b/tests/benchmark/test_utils.cc
@@ -103,14 +103,14 @@ Module OpBenchmarkTester::CreateCinnModule(const std::vector& input_tens
 void OpBenchmarkTester::CreateBuffer() {
   std::vector args;
   for (size_t i = 0; i < input_shapes_.size(); i++) {
-    auto* buffer = common::BufferBuilder(input_types_[i], input_shapes_[i]).set_align(512).set_random().Build();
+    auto* buffer = common::BufferBuilder(input_types_[i], input_shapes_[i]).set_align(32).set_random().Build();
     cinn_pod_value_t arg(buffer);
     all_args_.push_back(arg);
   }
   CHECK(!output_shapes_.empty()) << "output shapes shouldn't be empty\n";
   CHECK_EQ(output_shapes_.size(), out_types_.size());
   for (size_t i = 0; i < output_shapes_.size(); i++) {
-    auto* buffer = common::BufferBuilder(out_types_[i], output_shapes_[i]).set_align(512).set_zero().Build();
+    auto* buffer = common::BufferBuilder(out_types_[i], output_shapes_[i]).set_align(32).set_zero().Build();
     CHECK(buffer);
     out_dims_ = buffer->num_elements();
     cinn_pod_value_t arg(buffer);
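
Note (reviewer sketch, not part of the patch): the standalone C++ snippet below illustrates the GetBetterSplitFactor heuristic introduced in cinn/hlir/pe/schedule.cc above. Only the function body is taken from the patch; the main() driver and the sample extents are hypothetical, added here to show how the split factor adapts when the fused loop extent is smaller than the native vector width.

#include <cstdio>

// Copied from the schedule.cc hunk above: halve the requested split factor
// until it no longer exceeds the loop extent, then step back up one doubling
// so the factor still covers the whole extent.
int GetBetterSplitFactor(int shape, int split_factor) {
  int better_factor = split_factor;
  while (better_factor > shape) {
    better_factor /= 2;
  }
  if (better_factor < shape) return better_factor * 2;
  return better_factor;
}

int main() {
  // 16 corresponds to target_native_vector_bits / type_bits for 512-bit
  // vectors and fp32 (512 / 32); the AVX-256 case would use 8 instead.
  const int extents[] = {3, 4, 6, 8, 12, 16};
  for (int extent : extents) {
    std::printf("extent=%2d -> split factor %d\n", extent, GetBetterSplitFactor(extent, 16));
  }
  return 0;
}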