Commit 6f86970: Review comments applied

v-Golubev committed Oct 10, 2023
1 parent 495af60
Showing 4 changed files with 47 additions and 33 deletions.
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -310,8 +310,6 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
         const bool withReshape = parent->getType() == Type::Reshape;
         const auto reshapeNode = withReshape ? parent : nullptr;
         if (reshapeNode) {
-            if (reshapeNode->getInputShapeAtPort(0).getRank() != 3 && reshapeNode->getOutputShapeAtPort(0).getRank() != 2)
-                continue;
             parent = reshapeNode->getParentEdgesAtPort(0)[0]->getParent();
         }

@@ -358,6 +356,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
         const auto weightsShape = weightsNode->getOutputShapeAtPort(0);
         if (weightsShape != multiplyNode->getOutputShapeAtPort(0))
             continue;
+        if (reshapeNode && (reshapeNode->getInputShapeAtPort(0).getRank() != 3 || reshapeNode->getOutputShapeAtPort(0).getRank() != 2))
+            continue;

         VectorDims decompressionConstShape;
         const auto fcInputWeightsShape = fcNode->getInputShapeAtPort(1);
@@ -399,9 +399,9 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
         }

         // Fusion processing
-        fcNode->fuseDecompressionMultiplyPtr(multiplyConstNode);
+        fcNode->fuseDecompressionMultiply(multiplyConstNode);
         if (withSubtract)
-            fcNode->fuseDecompressionSubtractPtr(subtractConstNode);
+            fcNode->fuseDecompressionSubtract(subtractConstNode);

         fcNode->addOriginalLayer(multiplyNode->getOriginalLayers());
         fcNode->addOriginalLayer(convertNode->getOriginalLayers());
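The functional change in this file is twofold: the Reshape rank check now runs only after the Multiply/weights shape validation has passed, and the skip condition is corrected from `&&` to `||`, so the fusion accepts only an exact 3D -> 2D reshape. A minimal sketch of the old and new predicates (plain C++, illustration only, not the plugin code):

#include <cassert>
#include <cstddef>

// Old predicate: skip only when BOTH ranks mismatch, so e.g. a 3D -> 4D
// reshape slipped through even though the fusion expects 3D -> 2D.
bool skip_old(std::size_t in_rank, std::size_t out_rank) {
    return in_rank != 3 && out_rank != 2;
}

// Fixed predicate: skip unless the reshape is exactly 3D -> 2D.
bool skip_new(std::size_t in_rank, std::size_t out_rank) {
    return in_rank != 3 || out_rank != 2;
}

int main() {
    assert(!skip_old(3, 4));  // bug: a 3D -> 4D reshape was not rejected
    assert(skip_new(3, 4));   // fix: anything but 3D -> 2D is rejected
    assert(!skip_new(3, 2));  // the intended 3D -> 2D pattern still fuses
    return 0;
}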
25 changes: 19 additions & 6 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -1132,20 +1132,33 @@ bool FullyConnected::useSparseWeightsDecompression() {
     return true;
 }

-void FullyConnected::fuseDecompressionMultiplyPtr(const NodePtr& constData) {
-    fuseDecompressionConstantPtr(constData, decompressionMultiplyPtr);
+void FullyConnected::fuseDecompressionMultiply(const NodePtr& constData) {
+    fuseDecompressionConstant(constData, decompressionMultiplyPtr);
 }

-void FullyConnected::fuseDecompressionSubtractPtr(const NodePtr& constData) {
-    fuseDecompressionConstantPtr(constData, decompressionSubtractPtr);
+void FullyConnected::fuseDecompressionSubtract(const NodePtr& constData) {
+    fuseDecompressionConstant(constData, decompressionSubtractPtr);
 }

-void FullyConnected::fuseDecompressionConstantPtr(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr) {
+void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr) {
     auto *constInputNode = dynamic_cast<node::Input *>(constData.get());
     if (!constInputNode) {
         IE_THROW() << "Cannot cast " << constData->getName() << " to Input";
     }
-    decompressionValuesPtr = constInputNode->getMemoryPtr();
+    const auto decompression_prc = InferenceEngine::Precision::FP32;
+    if (constInputNode->getOriginalOutputPrecisionAtPort(0) == decompression_prc) {
+        decompressionValuesPtr = constInputNode->getMemoryPtr();
+    } else {
+        const auto constBlob = constInputNode->getMemoryPtr();
+        DnnlBlockedMemoryDesc memoryDesc(decompression_prc, constBlob->getShape());
+        decompressionValuesPtr = std::make_shared<Memory>(getEngine(), memoryDesc, nullptr, false);
+        const auto elementsCount = constBlob->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
+        cpu_convert(constBlob->getData(),
+                    decompressionValuesPtr->getData(),
+                    DnnlExtensionUtils::DataTypeToIEPrecision(constBlob->getDataType()),
+                    Precision::FP32,
+                    elementsCount);
+    }
 }

 DnnlMemoryDescPtr FullyConnected::makeTransposedWeightDescriptor(DnnlMemoryDescPtr desc) {
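The rewritten fuseDecompressionConstant normalizes every decompression constant to FP32: if the constant is already FP32 its memory is reused as-is, otherwise a new FP32 buffer is allocated and filled once via cpu_convert, so the conversion cost is paid at fusion time rather than on every inference. A self-contained sketch of the same idea (assumed helper, not the plugin's Memory/cpu_convert machinery):

#include <cstddef>
#include <cstdint>
#include <vector>

// Widen a lower-precision constant (here u8) to FP32 element by element,
// mirroring what cpu_convert does for the decompression scales/zero points.
template <typename SrcT>
std::vector<float> widen_to_fp32(const SrcT* src, std::size_t count) {
    std::vector<float> dst(count);
    for (std::size_t i = 0; i < count; ++i)
        dst[i] = static_cast<float>(src[i]);  // one-time conversion
    return dst;
}

int main() {
    const std::vector<uint8_t> zero_points{128, 127, 130};  // e.g. u8 zero points
    const auto fp32 = widen_to_fp32(zero_points.data(), zero_points.size());
    return fp32.size() == zero_points.size() ? 0 : 1;
}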
6 changes: 3 additions & 3 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
@@ -60,8 +60,8 @@ class FullyConnected : public Node {
         this->weightsNonTransposed = weightsNonTransposed;
     }

-    void fuseDecompressionMultiplyPtr(const NodePtr& constData);
-    void fuseDecompressionSubtractPtr(const NodePtr& constData);
+    void fuseDecompressionMultiply(const NodePtr& constData);
+    void fuseDecompressionSubtract(const NodePtr& constData);

 private:
     void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
@@ -99,7 +99,7 @@ class FullyConnected : public Node {
                                   const dnnl::engine& engine);

     bool canBeExecutedInConv1x1() const;
-    void fuseDecompressionConstantPtr(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr);
+    void fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr);

     // sparse weights
     bool useSparseWeights = false;
@@ -34,14 +34,15 @@ namespace SubgraphTestsDefinitions {

 struct ShapeParams {
     ShapeParams() = default;
-    ShapeParams(InputShape data_shape, ov::Shape weights_shape, size_t weights_group_size = 1)
+    ShapeParams(InputShape data_shape, ov::Shape weights_shape, int weights_group_size = -1)
         : data_shape(std::move(data_shape)),
           weights_shape(std::move(weights_shape)),
           weights_group_size(weights_group_size) {}

     InputShape data_shape;
     ov::Shape weights_shape;
-    size_t weights_group_size;
+    // Decompression group size. If the value is equal to -1, ordinary decompression is used
+    int weights_group_size;
 };
 using MatmulWeightsDecompressionParams = std::tuple<ShapeParams,
                                                     ov::test::ElementType,  // weights precision
@@ -96,7 +97,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig

 protected:
     std::shared_ptr<ov::Node> initDecompressionWeights(const ov::Shape& weights_shape,
-                                                       const size_t group_size,
+                                                       const int group_size,
                                                        const ov::element::Type data_precision,
                                                        const ov::element::Type weights_precision,
                                                        const bool transpose_weights,
@@ -109,23 +110,23 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
             return result_shape;
         };

+        const bool group_decompression = group_size != -1;
         // Weights has shape [I, O], where
         // I - input channels
         // O - output channels
-        // If group size greater than 1, input channels dimension are split into 2: I -> [N, G], where
+        // In case of group decompression, input channels dimension is split into 2: I -> [N, G], where
         // N - number of groups
         // G - group size
         auto transformed_weights_shape = transpose_if_necessary(weights_shape);
-        OPENVINO_ASSERT(weights_shape[0] % group_size == 0,
-                        "Weights output channels count (",
-                        weights_shape[0],
-                        ") must be divisible by decompression group size (",
-                        group_size,
-                        ").");
-        const size_t number_of_groups = weights_shape[0] / group_size;
-        if (group_size > 1) {
+        if (group_decompression) {
+            OPENVINO_ASSERT(weights_shape[0] % group_size == 0,
+                            "Weights output channels count (",
+                            weights_shape[0],
+                            ") must be divisible by decompression group size (",
+                            group_size,
+                            ").");
             auto in_channel_idx = transpose_weights ? transformed_weights_shape.size() - 1 : transformed_weights_shape.size() - 2;
-            transformed_weights_shape[in_channel_idx] = number_of_groups;
+            transformed_weights_shape[in_channel_idx] = weights_shape[0] / group_size;
             transformed_weights_shape.insert(transformed_weights_shape.begin() + in_channel_idx + 1, group_size);
         }
         auto weights = ngraph::builder::makeConstant<uint8_t>(weights_precision, transformed_weights_shape, {}, true);
@@ -136,12 +137,12 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
         auto output_channels = *weights_shape.rbegin();

         // Decompression constants shape:
-        // if group size = 1: [O, 1]
-        // otherwise: [O, N, 1]
+        // Ordinary decompression: [O, 1]
+        // Group decompression: [O, N, 1]
         ov::Shape scaleshift_target_shape{output_channels};
-        scaleshift_target_shape.insert(scaleshift_target_shape.begin(), group_size == 1 ? 1 : number_of_groups);
+        scaleshift_target_shape.insert(scaleshift_target_shape.begin(), group_decompression ? weights_shape[0] / group_size : 1);
         scaleshift_target_shape = transpose_if_necessary(scaleshift_target_shape);
-        if (group_size > 1) {
+        if (group_decompression) {
             auto in_channel_idx = transpose_weights ? scaleshift_target_shape.size() - 1 : scaleshift_target_shape.size() - 2;
             scaleshift_target_shape.insert(scaleshift_target_shape.begin() + in_channel_idx + 1, 1);
         }
@@ -168,7 +169,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
         }
         std::shared_ptr<ov::Node> last_node = std::make_shared<ov::opset10::Multiply>(mul_parent, scale_const);

-        if (group_size > 1) {
+        if (group_decompression) {
             auto reshape_target_shape = transpose_weights ? std::vector<int>{-1, static_cast<int>(weights_shape[0])}
                                                           : std::vector<int>{static_cast<int>(weights_shape[0]), -1};
             auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {reshape_target_shape.size()}, reshape_target_shape);
@@ -187,7 +188,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig

     std::shared_ptr<ov::Model> initSubgraph(const ov::PartialShape& data_shape,
                                             const ov::Shape& weights_shape,
-                                            const size_t group_size,
+                                            const int group_size,
                                             const ov::element::Type data_precision,
                                             const ov::element::Type weights_precision,
                                             const bool transpose_weights,
@@ -247,7 +248,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
         const auto& test_param = GetParam();
         const auto& weights_precision = std::get<1>(test_param);
         // TODO: remove this condition when group decompression is supported
-        if (weights_precision == ov::element::nf4 || std::get<0>(test_param).weights_group_size != 1) {
+        if (weights_precision == ov::element::nf4 || std::get<0>(test_param).weights_group_size != -1) {
            return;
         }
         bool weights_found = false;
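The test-side refactor replaces the `group_size == 1` sentinel with an explicit `-1` meaning "ordinary (per-channel) decompression", so a literal group size of 1 is no longer conflated with the ungrouped case. A standalone sketch of the resulting shape arithmetic for non-transposed weights, as I read the test code (hypothetical helper, not part of the test suite):

#include <cstddef>
#include <vector>

struct DecompressionShapes {
    std::vector<std::size_t> weights;     // transformed weights shape
    std::vector<std::size_t> scaleshift;  // scale / zero-point constant shape
};

// Weights are [I, O]. With group decompression, I is split into [N, G]
// (N = I / G groups of size G) and the constants become [N, 1, O] instead
// of the ordinary per-channel [1, O].
DecompressionShapes make_shapes(std::size_t I, std::size_t O, int group_size) {
    const bool group_decompression = group_size != -1;  // -1 == ordinary mode
    if (!group_decompression)
        return {{I, O}, {1, O}};
    const auto G = static_cast<std::size_t>(group_size);
    const auto N = I / G;  // I % G must be 0, as the OPENVINO_ASSERT enforces
    return {{N, G, O}, {N, 1, O}};
}

int main() {
    const auto ordinary = make_shapes(16, 32, -1);  // weights [16, 32], scales [1, 32]
    const auto grouped = make_shapes(16, 32, 4);    // weights [4, 4, 32], scales [4, 1, 32]
    return (ordinary.scaleshift.size() == 2 && grouped.scaleshift.size() == 3) ? 0 : 1;
}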
