From 8f27b9d7bc811f76fdccb55677c46c76f6ec2eaf Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 13 Oct 2017 15:27:41 -0700 Subject: [PATCH 1/3] [CODEGEN] Detect broadcast(cast(x)) pattern in FMA --- src/pass/lower_intrin.cc | 42 +++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 84f3cb5bba6f..bce753dfc107 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -34,22 +34,46 @@ class IntrinInjecter : public IRMutator { } Expr Mutate_(const Add* op, const Expr& e) final { - if (fma_ == nullptr || !op->type.is_float()) { - return IRMutator::Mutate_(op, e); - } if (const Mul* mb = op->b.as()) { - Expr r = (*fma_)(Call::make( - op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic)); - if (r.defined()) return this->Mutate(r); + Expr lhs = SwapBroadcastCast(mb->a); + Expr rhs = SwapBroadcastCast(mb->b); + + if (fma_ != nullptr && op->type.is_float()) { + Expr r = (*fma_)(Call::make( + op->type, "fma", {lhs, rhs, op->a}, Call::PureIntrinsic)); + if (r.defined()) return this->Mutate(r); + } else { + Expr a = this->Mutate(op->a); + Expr b = this->Mutate(Mul::make(lhs, rhs)); + return Add::make(a, b); + } } else if (const Mul* ma = op->a.as()) { - Expr r = (*fma_)(Call::make( - op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic)); - if (r.defined()) return this->Mutate(r); + Expr lhs = SwapBroadcastCast(ma->a); + Expr rhs = SwapBroadcastCast(ma->b); + + if (fma_ != nullptr && op->type.is_float()) { + Expr r = (*fma_)(Call::make( + op->type, "fma", {lhs, rhs, op->b}, Call::PureIntrinsic)); + if (r.defined()) return this->Mutate(r); + } else { + Expr a = this->Mutate(Mul::make(lhs, rhs)); + Expr b = this->Mutate(op->b); + return Add::make(a, b); + } } return IRMutator::Mutate_(op, e); } private: + Expr SwapBroadcastCast(Expr e) { + if (const Broadcast* bcast = e.as()) { + if (const Cast* cast = bcast->value.as()) { + Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); + return Cast::make(cast->type, new_bcast); + } + } + return e; + } Expr ApplyPattern(const std::string& name, const Expr& e) { for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i]; From 3c98e80d4e7cf3df0759351f8863053449e6ba3c Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 13 Oct 2017 17:11:25 -0700 Subject: [PATCH 2/3] [CODEGEN] Improve --- src/pass/lower_intrin.cc | 58 +++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index bce753dfc107..7fa3219ce07b 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -35,45 +35,49 @@ class IntrinInjecter : public IRMutator { Expr Mutate_(const Add* op, const Expr& e) final { if (const Mul* mb = op->b.as()) { - Expr lhs = SwapBroadcastCast(mb->a); - Expr rhs = SwapBroadcastCast(mb->b); - - if (fma_ != nullptr && op->type.is_float()) { - Expr r = (*fma_)(Call::make( - op->type, "fma", {lhs, rhs, op->a}, Call::PureIntrinsic)); - if (r.defined()) return this->Mutate(r); - } else { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(Mul::make(lhs, rhs)); - return Add::make(a, b); - } + return MakeFMA(mb->a, mb->b, op->a, op, e); } else if (const Mul* ma = op->a.as()) { - Expr lhs = SwapBroadcastCast(ma->a); - Expr rhs = SwapBroadcastCast(ma->b); - - if (fma_ != nullptr && op->type.is_float()) { - Expr r = (*fma_)(Call::make( - op->type, "fma", {lhs, rhs, op->b}, Call::PureIntrinsic)); - if (r.defined()) return this->Mutate(r); - } else { - Expr a = this->Mutate(Mul::make(lhs, rhs)); - Expr b = this->Mutate(op->b); - return Add::make(a, b); - } + return MakeFMA(ma->a, ma->b, op->b, op, e); } return IRMutator::Mutate_(op, e); } private: - Expr SwapBroadcastCast(Expr e) { + Expr SwapBroadcastCast(const Expr& e) { + // Try to change broadcast(cast(x)) to cast(broadcast(x)) + // For some targets, LLVM will generate more efficient FMA + // instruction with the latter. For example, vmla vs. vmlal + // on ARM. if (const Broadcast* bcast = e.as()) { if (const Cast* cast = bcast->value.as()) { - Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); - return Cast::make(cast->type, new_bcast); + if (cast->type == cast->value.type().with_bits(cast->value.type().bits() * 2)) { + Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); + return Cast::make(bcast->type, new_bcast); + } } } return e; } + + Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c, + const Add* op, const Expr& e) { + // emit fma instruction: a * b + c + Expr lhs = SwapBroadcastCast(a); + Expr rhs = SwapBroadcastCast(b); + + if (fma_ != nullptr && op->type.is_float()) { + Expr r = (*fma_)(Call::make( + op->type, "fma", {lhs, rhs, op->b}, Call::PureIntrinsic)); + if (r.defined()) return this->Mutate(r); + } else { + if (!(lhs.same_as(a) && rhs.same_as(b))) { + Expr mul = this->Mutate(Mul::make(lhs, rhs)); + return Add::make(mul, this->Mutate(c)); + } + } + return IRMutator::Mutate_(op, e); + } + Expr ApplyPattern(const std::string& name, const Expr& e) { for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i]; From fc914dfb80885d477eb80af748e3afaa47f4a42b Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 13 Oct 2017 17:50:16 -0700 Subject: [PATCH 3/3] [CODEGEN] Fix --- src/pass/lower_intrin.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 7fa3219ce07b..33ac6a94ecf7 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -50,7 +50,7 @@ class IntrinInjecter : public IRMutator { // on ARM. if (const Broadcast* bcast = e.as()) { if (const Cast* cast = bcast->value.as()) { - if (cast->type == cast->value.type().with_bits(cast->value.type().bits() * 2)) { + if (cast->type.bits() == cast->value.type().bits() * 2) { Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); return Cast::make(bcast->type, new_bcast); } @@ -67,10 +67,10 @@ class IntrinInjecter : public IRMutator { if (fma_ != nullptr && op->type.is_float()) { Expr r = (*fma_)(Call::make( - op->type, "fma", {lhs, rhs, op->b}, Call::PureIntrinsic)); + op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic)); if (r.defined()) return this->Mutate(r); } else { - if (!(lhs.same_as(a) && rhs.same_as(b))) { + if (!lhs.same_as(a) || !rhs.same_as(b)) { Expr mul = this->Mutate(Mul::make(lhs, rhs)); return Add::make(mul, this->Mutate(c)); }