Skip to content

Commit

Permalink
NPUW: Support asymmetric weights quantization (Opt B) (#26164)
Browse files Browse the repository at this point in the history
### Details:
There has a case emerged where now there are vectors of zero points
instead of a single value. Adding the support for unpacking such
asymmetric zeropoints.


### Tickets:
 - *134720*

---------

Co-authored-by: Dmitry Matveev <[email protected]>
  • Loading branch information
ujjayant-kadian and dmatveev authored Aug 23, 2024
1 parent d3096af commit 5b99ebc
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,9 @@ void Partitioner::decompressionCutOff(const std::string& func_name) {
// Phi-3 4SymW16A/GPTQ
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassCWAI3>(dcoff_mode, dcoff_type, std::ref(params_to));

// Asymmetric zeropoints
rewr.add_matcher<ov::npuw::patterns::AsymmZP::DCOFFPassReshape>(dcoff_mode, dcoff_type, std::ref(params_to));

rewr.run_on_model(f._model);

ov::pass::Validate val;
Expand Down
146 changes: 138 additions & 8 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,36 +56,57 @@ ClosureRemap build_remap(const Function& fbody, const DCOFFParams& params_to) {

ClosureRemap m;

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
std::unordered_set<PPtr> ban_list;

for (const auto& scale_pair : params_to.scales) {
ban_list.insert(scale_pair.first);
}

for (const auto& zerop_pair : params_to.zerops_asymm) {
ban_list.insert(zerop_pair.second);
}

// FIXME: use indexed() here instead
for (std::size_t i = fbody._param_offset; i < body_params.size(); i++) {
LOG_DEBUG("Checking the function parameter " << body_params[i]);
const auto& param = body_params[i];
LOG_DEBUG("Checking the function parameter " << param);
LOG_BLOCK();

// First find among scale factors...
auto pscale_iter = params_to.scales.find(body_params[i]);
auto pscale_iter = params_to.scales.find(param);
auto pzerop_iter = params_to.zerops_asymm.find(param);
if (pscale_iter != params_to.scales.end()) {
LOG_DEBUG("This is a Scale factor parameter, will be removed");
auto& pscale_weight_param = pscale_iter->second;
auto pscale_weight_pindex = fbody._model->get_parameter_index(pscale_weight_param);
auto pscale_weight_cindex = pscale_weight_pindex - fbody._param_offset;
m.scale_remap[pscale_weight_cindex] = i - fbody._param_offset;
m.params_to_remove.push_back(body_params[i]);
} else {
m.params_to_remove.push_back(param);
} else if (pzerop_iter != params_to.zerops_asymm.end()) {
LOG_DEBUG("There is an Asymmetric zero point corresponding to this parameter, it will be removed");
auto zerop_pindex = fbody._model->get_parameter_index(pzerop_iter->second);
auto zerop_cindex = zerop_pindex - fbody._param_offset;
m.zerop_remap[i - fbody._param_offset] = zerop_cindex;
m.params_to_remove.push_back(pzerop_iter->second);
m.closure_remap.push_back(i - fbody._param_offset);
} else if (ban_list.find(param) == ban_list.end()) {
// If it's not in the ban list, it's an OK parameter and should be kept
LOG_DEBUG("This is an OK parameter, will be kept");
// n++ is the index of `i` here
m.closure_remap.push_back(i - fbody._param_offset);
}

// Process zero points for parameters
auto zerop_iter = params_to.zerops.find(body_params[i]);
auto zerop_iter = params_to.zerops.find(param);
if (zerop_iter != params_to.zerops.end()) {
LOG_DEBUG("This parameter requires zero point: " << zerop_iter->second);
m.zero_points.push_back(ov::npuw::util::tensor_from_const(zerop_iter->second));
} else {
m.zero_points.push_back(ov::Tensor());
}
}
NPUW_ASSERT((body_params.size() - fbody._param_offset) == (m.scale_remap.size() + m.closure_remap.size()));
NPUW_ASSERT((body_params.size() - fbody._param_offset) ==
(m.scale_remap.size() + m.closure_remap.size() + m.zerop_remap.size()));
NPUW_ASSERT((body_params.size() - fbody._param_offset) == m.zero_points.size());

LOG_DEBUG("DONE");
Expand All @@ -105,7 +126,10 @@ void apply_remap(Subgraph& fcall, const ClosureRemap& m) {

auto scale_iter = m.scale_remap.find(i);
new_scales.push_back(scale_iter != m.scale_remap.end() ? fcall._closure[scale_iter->second] : ov::Tensor());
new_zerops.push_back(m.zero_points[i]);
// Check for asymmetric zero points and add them to new_zerops
auto zerop_iter = m.zerop_remap.find(i);
const auto& zerop = zerop_iter != m.zerop_remap.end() ? fcall._closure[zerop_iter->second] : m.zero_points[i];
new_zerops.push_back(zerop);
}
fcall._closure = std::move(new_closure);
fcall._scales = std::move(new_scales);
Expand Down Expand Up @@ -790,6 +814,112 @@ CWAI3::CWAI3(CWAI3::Results scales) {
// Implementation TBD

} // namespace SymmZP

//------------------------------------------------------------------------------
// Pattern: ASymmZP, weights with asymmetric quantization
//
namespace AsymmZP {
// As seen in asymmetric TinyLlama:
// Since it is ASymm, all zero points for all blocks have different
// values so they will be Parameters but not Constants.
//
// In the diagram below, pattern on the left is identified and
// is modified to pattern in the right if type is promoted to f16
//
// "tensor" "zero point" "scale"
// Parameter:A Parameter:B Parameter:C > Parameter:A
// u4 u4 f16 > f16 <Const>
// : : : > : :
// V : : > V V
// Convert Convert : > Reshape|Convert
// f16 f16 : >
// : : : >
// V V : >
// Subtract : >
// f16 : >
// : : >
// V V >
// Multiply >
// fp16 <Const> >
// : : >
// V V >
// Reshape|Convert >
// : >
// V >
//
DCOFFPassReshape::DCOFFPassReshape(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
auto paramA = opp::wrap_type<ov::op::v0::Parameter>();
auto paramB = opp::wrap_type<ov::op::v0::Parameter>();
auto paramC = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA});
auto cvtB = opp::wrap_type<ov::op::v0::Convert>({paramB});
auto subtr = opp::wrap_type<ov::op::v1::Subtract>({cvtA, cvtB});
auto mulply = opp::wrap_type<ov::op::v1::Multiply>({subtr, paramC});

auto scalar = opp::wrap_type<ov::op::v0::Constant>();
auto reshpe = opp::wrap_type<ov::op::v1::Reshape>({mulply, scalar});

auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
auto matched_nodeA = node_to_output.at(paramA).get_node_shared_ptr();
auto matched_nodeB = node_to_output.at(paramB).get_node_shared_ptr();
auto matched_nodeC = node_to_output.at(paramC).get_node_shared_ptr();

NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeA));
NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeB));
NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeC));

auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
auto matched_paramB = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeB);
auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);

if (ov::element::u4 == matched_paramA->get_element_type() &&
ov::element::u4 == matched_paramB->get_element_type() &&
ov::element::f16 == matched_paramC->get_element_type()) {
LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << dcoff_type);
matched_paramA->set_element_type(dcoff_type);

if (dcoff_mode == DCOffMode::CAST_SCALE) {
NPUW_ASSERT(dcoff_type == ov::element::f16);

LOG_DEBUG("Matched: " << matched_paramB << " - value to remove...");
LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove...");
LOG_BLOCK();

// Extra transformation here:
// - remove Subtract + Multiply,
// - mark paramC for removal.
// Reshape will be reconnected to ParamA directly

// Record mapping from the Scale coeff parameter to the Real weight parameter
pref.get().zerops_asymm[matched_paramA] = matched_paramB;
pref.get().scales[matched_paramC] = matched_paramA;

// Disconnect Multiply and Convert from their outputs
auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr();
auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr();
auto drop_outputs = [](std::shared_ptr<ov::Node> node) {
for (auto&& node_outputs : node->outputs()) {
for (auto&& node_reader_port : node_outputs.get_target_inputs()) {
node_outputs.remove_target_input(node_reader_port);
}
}
};
LOG_DEBUG("Dropping the connections...");
drop_outputs(matched_mulply);
drop_outputs(matched_convrt);

LOG_DEBUG("Reconnecting the Root...");
auto matched_reshpe = node_to_output.at(reshpe).get_node_shared_ptr();
matched_reshpe->input(0).replace_source_output(matched_paramA);
}
LOG_DEBUG("Done");
}
return false; // root node hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(reshpe, "TagDCOFFReshape"), std::move(callback));
}
} // namespace AsymmZP
} // namespace patterns
} // namespace npuw
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ namespace patterns {
struct DCOFFParams {
using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
std::unordered_map<PPtr, PPtr> scales; // Closures: a scaling factor -> orig tensor
std::unordered_map<PPtr, CPtr> zerops; // Closures: orig tensor -> a zero point (yes, a reverse...)
std::unordered_map<PPtr, PPtr> scales; // Closures: a scaling factor -> orig tensor
std::unordered_map<PPtr, CPtr> zerops; // Closures: orig tensor -> a zero point (yes, a reverse...)
std::unordered_map<PPtr, PPtr> zerops_asymm; // Closures: orig tensor -> an asymmetric zerop parameter
};

using DCOFFParamRef = std::reference_wrapper<DCOFFParams>;

struct ClosureRemap {
std::vector<std::size_t> closure_remap; // [new closure index] -> orig closure idx
std::map<std::size_t, std::size_t> scale_remap; // orig closure idx -> orig scale idx
std::map<std::size_t, std::size_t> zerop_remap; // orig closure idx -> orig asymm zero point idx
ov::ParameterVector params_to_remove;

std::vector<ov::Tensor> zero_points; // zero points for closures, if needed
Expand Down Expand Up @@ -160,6 +162,14 @@ class CWAI3 : public ov::pass::MatcherPass {

} // namespace SymmZP

namespace AsymmZP {
class DCOFFPassReshape : public ov::pass::MatcherPass {
public:
DCOFFPassReshape(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
};

} // namespace AsymmZP

} // namespace patterns
} // namespace npuw
} // namespace ov
Loading

0 comments on commit 5b99ebc

Please sign in to comment.