Skip to content

Commit

Permalink
added missed code to prepareWeightMemory
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed Dec 16, 2024
1 parent 0079f5f commit 5455c1b
Showing 1 changed file with 11 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,19 @@ MemoryPtr acl_fc_executor::prepareWeightMemory(const MemoryArgs &memory,
MemoryArgs memoryArgs;
memoryArgs[ARG_BIAS] = memory.at(ARG_BIAS);
memoryArgs[ARG_WEI] = memory.at(ARG_WEI);

auto originalWeightsDesc = memory.at(ARG_WEI)->getDescPtr();
// normalize weights to 2D
const auto& wgtDims = originalWeightsDesc->getShape().getStaticDims();
const VectorDims wgtDims2D = reshapeDownToRank<2>(wgtDims);
originalWeightsDesc = std::make_shared<CpuBlockedMemoryDesc>(originalWeightsDesc->getPrecision(), Shape{wgtDims2D});
auto dnnlSrcDesc = MemoryDescUtils::convertToDnnlMemoryDesc(originalWeightsDesc);
auto dstDesc = originalWeightsDesc->cloneWithNewPrecision(aclfcAttrs.inputPrecision);
auto dnnlDstDesc = MemoryDescUtils::convertToDnnlMemoryDesc(dstDesc);

if (memory.at(ARG_SRC_0)->getShape().isDynamic()) {
const auto& inShape = memory.at(ARG_SRC_0)->getShape();
const auto& wShape = memory.at(ARG_WEI)->getShape();
const auto& wShape = originalWeightsDesc->getShape();
const auto& inDymmyDims = makeDummyInputDims(inShape, wShape);
const auto& outDymmyDims = makeDummyOutputDims(inDymmyDims, wShape.getStaticDims(), memory.at(ARG_DST)->getShape().getRank());
memoryArgs[ARG_SRC_0] = std::make_shared<Memory>(context->getEngine(),
Expand All @@ -219,13 +229,6 @@ MemoryPtr acl_fc_executor::prepareWeightMemory(const MemoryArgs &memory,
expectedWeightFormat = isNeededReorder ? aclWeightsRepack->getOptImplWeightFormat() : arm_compute::WeightFormat::UNSPECIFIED;
weiTensorInfo = aclWeightsRepack->getTensorInfo(ACLArgs::ACL_WEI);

MemoryPtr dstMemPtr = std::make_shared<Memory>(context->getEngine(),
memory.at(ARG_WEI)->getDescPtr()->cloneWithNewPrecision(aclfcAttrs.inputPrecision));
auto dstDesc = dstMemPtr->getDescPtr();
auto dnnlDstDesc = MemoryDescUtils::convertToDnnlMemoryDesc(dstDesc);
auto weiDesc = memory.at(ARG_WEI)->getDescPtr();
auto dnnlSrcDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weiDesc);

if (isNeededReorder) {
dnnl::impl::dim_t o_dim = 0;
dnnl::impl::dim_t inner_dim = 1;
Expand Down

0 comments on commit 5455c1b

Please sign in to comment.