Skip to content

Commit

Permalink
optimize relu, relu6; optimize schedule with ugly shape (PaddlePaddle…
Browse files Browse the repository at this point in the history
  • Loading branch information
wenming2014 authored Jan 7, 2021
1 parent dcd4775 commit 7fb48cc
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 45 deletions.
18 changes: 6 additions & 12 deletions cinn/backends/codegen_c_x86.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ void CodeGenCX86::Visit(const ir::Load *op) {
CHECK(op->type().is_vector());

int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512()) {
CHECK_EQ(bits, 512);
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_load(";
PrintAbsAddr(op);
os() << ")";
} else if (SupportsAVX256()) {
CHECK_EQ(bits, 256);
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_load(";
PrintAbsAddr(op);
os() << ")";
Expand All @@ -36,13 +34,11 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
CHECK_GT(op->type().lanes(), 1);
int bits = op->type().bits() * op->type().lanes();

if (SupportsAVX512()) {
CHECK_EQ(bits, 512);
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_set1(";
PrintCastExpr(op->value.type().ElementOf(), op->value);
os() << ")";
} else if (SupportsAVX256()) {
CHECK_EQ(bits, 256);
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_set1(";
PrintCastExpr(op->value.type().ElementOf(), op->value);
os() << ")";
Expand All @@ -58,15 +54,13 @@ void CodeGenCX86::Visit(const ir::Store *op) {
}

int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512()) {
CHECK_EQ(bits, 512);
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_store(";
PrintAbsAddr(op);
os() << ", ";
Print(op->value);
os() << ")";
} else if (SupportsAVX256()) {
CHECK_EQ(bits, 256);
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_store(";
PrintAbsAddr(op);
os() << ", ";
Expand Down
8 changes: 3 additions & 5 deletions cinn/backends/codegen_c_x86.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,20 @@ void CodeGenCX86::VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string

// TODO(Superjomn) Consider support BLAS.
int bits = a.type().bits() * a.type().lanes();
if (SupportsAVX512()) {
CHECK_EQ(bits, 512) << "the bits of computation should be times of 512";
if (SupportsAVX512() && bits == 512) {
os() << "cinn_avx512_" << op_repr << "(";
PrintVecInputArgument(&a);
os() << ", ";
PrintVecInputArgument(&b);
os() << ")";
} else if (SupportsAVX256()) {
CHECK_EQ(bits, 256) << "the bits of computation should be times of 256";
} else if (SupportsAVX256() && bits == 256) {
os() << "cinn_avx256_" << op_repr << "(";
PrintVecInputArgument(&a);
os() << ", ";
PrintVecInputArgument(&b);
os() << ")";
} else {
CINN_NOT_IMPLEMENTED
CodeGenC::Visit(op);
}
}

Expand Down
4 changes: 2 additions & 2 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) {
auto *vtype =
llvm::VectorType::get(ll_type_of(op->type().ElementOf()), llvm::ElementCount(lanes, false /*Scalable*/))
->getPointerTo();
int alignment = lanes * op->type().ElementOf().bits() / 8;
int alignment = std::max(op->type().ElementOf().bits() / 8, 1);
llvm::StoreInst *inst =
b_->CreateAlignedStore(CreateVecSlice(value, offset, lanes), b_->CreatePointerCast(ptr, vtype), alignment);
AddTbaaMetadata(inst, op->tensor.as_tensor()->name, base);
Expand Down Expand Up @@ -1086,7 +1086,7 @@ llvm::Value *CodeGenLLVM::DenseVectorLoad(const ir::Load *op) {
llvm::Value *elt_ptr = CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&slice_base));
llvm::Value *vec_ptr = b_->CreatePointerCast(elt_ptr, slice_type->getPointerTo(), "get_vec_ptr");

int alignment = slice_lanes * op->type().ElementOf().bits() / 8;
int alignment = std::max(op->type().ElementOf().bits() / 8, 1);

llvm::Instruction *load_inst = b_->CreateAlignedLoad(vec_ptr, llvm::Align(alignment), "load_vec");
AddTbaaMetadata(load_inst, op->tensor.as_tensor()->name, op->index());
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class _Tensor_ : public Object {
inline T* mutable_data(const Target& target) {
set_type(type_of<T>());
if (target == common::DefaultHostTarget()) {
int alignment = target.get_target_bits() * 8;
int alignment = type_of<T>().ElementOf().bits();
buffer_->ResizeLazy(alignment, shape_.numel() * sizeof(T), target);
} else {
buffer_->ResizeLazy(shape_.numel() * sizeof(T), target);
Expand Down
22 changes: 16 additions & 6 deletions cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 2UL);
if (target.arch == Target::Arch::NVGPU) {
Expr Out = arg_pack[0];
Expr out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(Out.as_tensor());
pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
CHECK(out.as_tensor());
pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.back(), target);
} else if (target.arch == Target::Arch::X86) {
Expr out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(out.as_tensor());
pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.back(), target);
}
*ret = arg_pack;
});
Expand Down Expand Up @@ -92,10 +97,15 @@ std::shared_ptr<OpStrategy> StrategyForRelu6(const framework::NodeAttr &attrs,
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 2UL);
if (target.arch == Target::Arch::NVGPU) {
Expr Out = arg_pack[0];
Expr out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(Out.as_tensor());
pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
CHECK(out.as_tensor());
pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.back(), target);
} else if (target.arch == Target::Arch::X86) {
Expr out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(out.as_tensor());
pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.back(), target);
}
*ret = arg_pack;
});
Expand Down
4 changes: 3 additions & 1 deletion cinn/hlir/pe/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ ir::Tensor Pad(const ir::Tensor &tensor,
const std::string &name = UniqName("T_pad_out"),
const std::string &pad_mode = "constant");

std::vector<ir::Tensor> Softmax(const ir::Tensor &A, int axis=-1, const std::string &output_name= UniqName("T_softmax_out"));
std::vector<ir::Tensor> Softmax(const ir::Tensor &A,
int axis = -1,
const std::string &output_name = UniqName("T_softmax_out"));

ir::Tensor Slice(const ir::Tensor &A,
const std::vector<int> &starts,
Expand Down
25 changes: 17 additions & 8 deletions cinn/hlir/pe/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ namespace cinn {
namespace hlir {
namespace pe {

int GetBetterSplitFactor(int shape, int split_factor) {
int better_factor = split_factor;
while (better_factor > shape) {
better_factor /= 2;
}
if (better_factor < shape) return better_factor * 2;
return better_factor;
}

void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target) {
int dims = stage->n_out_dims();
if (dims > 1) {
Expand All @@ -24,19 +33,19 @@ void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector<int> &output_sha
if (last_two_dim_bits % target_native_vector_bits == 0) {
fused = stage->Fuse(dims - 2, dims - 1);
prod_size *= output_shape[dims - 2];
} else {
return;
}
}
int split_factor = target_native_vector_bits / type_bits;
if (prod_size == split_factor) {
stage->Vectorize(fused, split_factor);
return;
if (prod_size <= split_factor) {
split_factor = GetBetterSplitFactor(prod_size, split_factor);
if (split_factor >= 8) {
stage->Vectorize(fused, split_factor);
}
} else {
auto [j_outer, j_inner] = stage->Split(fused, split_factor);
stage->Vectorize(j_inner, split_factor);
}
auto [j_outer, j_inner] = stage->Split(fused, split_factor);
stage->Vectorize(j_inner, split_factor);
}

return;
}

Expand Down
17 changes: 13 additions & 4 deletions cinn/optim/vectorize_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class Vectorizer : public IRMutator<Expr *> {
}

template <typename T>
Expr BinaryOperatorVec(const T *op, Expr *expr) {
void BinaryOperatorVec(const T *op, Expr *expr) {
auto *node = expr->As<T>();
Expr a0 = node->a();
Expr b0 = node->b();
Expand All @@ -323,7 +323,7 @@ class Vectorizer : public IRMutator<Expr *> {
// if (a0.same_as(node->a()) && b0.same_as(node->b())) return *expr;

int lanes = std::max(node->a().type().lanes(), node->b().type().lanes());
return T::Make(Widen(node->a(), lanes), Widen(node->b(), lanes));
*expr = T::Make(Widen(node->a(), lanes), Widen(node->b(), lanes));
}
};

Expand Down Expand Up @@ -521,8 +521,17 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
if (!for_min_i) return Expr();
if (for_min_i->value != 0) return Expr();

Expr times = common::AutoSimplify(Div::Make(forloop->extent, make_const(factor)));
Simplify(&times);
auto *extent_ptr = forloop->extent.As<IntImm>();
Expr times;
if (extent_ptr) {
int extent_int = forloop->extent.as_int32();
int extent_trunc = extent_int / factor;
int extent_times = extent_int % factor == 0 ? extent_trunc : extent_trunc + 1;
times = common::make_const(forloop->extent->type(), extent_times);
} else {
times = common::AutoSimplify(Div::Make(forloop->extent, make_const(factor)));
Simplify(&times);
}

// update the current forloop
auto times_int = times.As<IntImm>();
Expand Down
8 changes: 5 additions & 3 deletions cinn/optim/vectorize_loops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void matmul(void* _args, int32_t num_args)
const float* B = ((const float*)(_B->memory));
float* C = ((float*)(_C->memory));
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j = 0; j < 31; j += 1) {
for (int32_t j = 0; j < 32; j += 1) {
C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j)), 1, 16)] = (StackedVec<float,16>::Load(A,((500 * i) + (16 * j))) * StackedVec<float,16>::Load(B,((500 * i) + (16 * j))));
};
};
Expand Down Expand Up @@ -136,11 +136,13 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
float* D = (float*)(_C->memory);
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j_outer = 0; j_outer < 31; j_outer += 1) {
C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j_outer)), 1, 16)] = (StackedVec<float,16>::Load(A,((500 * i) + (16 * j_outer))) * StackedVec<float,16>::Load(B,((500 * i) + (16 * j_outer))));
C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j_outer)), 1, 16)] = (StackedVec<float,16>::Load(A,((500 * i) +
(16 * j_outer))) * StackedVec<float,16>::Load(B,((500 * i) + (16 * j_outer))));
};
for (int32_t j_outer = 31; j_outer < 32; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-16 * j_outer)); j_inner += 1) {
C[((500 * i) + ((16 * j_outer) + j_inner))] = (A[((500 * i) + ((16 * j_outer) + j_inner))] * B[((500 * i) + ((16 * j_outer) + j_inner))]);
C[((500 * i) + ((16 * j_outer) + j_inner))] = (A[((500 * i) + ((16 * j_outer) + j_inner))] * B[((500 * i) +
((16 * j_outer) + j_inner))]);
};
};
};
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def to_test_op(self, input_shapes, output_shapes, op_name, attrs):
temp_inputs = []
alignment = 0
if self.target.arch == common.Target.Arch.X86:
alignment = 512
alignment = 32
for in_data in inputs_data:
temp_inputs.append(
runtime.cinn_buffer_t(in_data, runtime.cinn_x86_device,
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmark/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ Module OpBenchmarkTester::CreateCinnModule(const std::vector<Tensor>& input_tens
void OpBenchmarkTester::CreateBuffer() {
std::vector<cinn_pod_value_t> args;
for (size_t i = 0; i < input_shapes_.size(); i++) {
auto* buffer = common::BufferBuilder(input_types_[i], input_shapes_[i]).set_align(512).set_random().Build();
auto* buffer = common::BufferBuilder(input_types_[i], input_shapes_[i]).set_align(32).set_random().Build();
cinn_pod_value_t arg(buffer);
all_args_.push_back(arg);
}
CHECK(!output_shapes_.empty()) << "output shapes shouldn't be empty\n";
CHECK_EQ(output_shapes_.size(), out_types_.size());
for (size_t i = 0; i < output_shapes_.size(); i++) {
auto* buffer = common::BufferBuilder(out_types_[i], output_shapes_[i]).set_align(512).set_zero().Build();
auto* buffer = common::BufferBuilder(out_types_[i], output_shapes_[i]).set_align(32).set_zero().Build();
CHECK(buffer);
out_dims_ = buffer->num_elements();
cinn_pod_value_t arg(buffer);
Expand Down

0 comments on commit 7fb48cc

Please sign in to comment.