Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quant_params_count_limit to VisualizeConfig #147

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "direct_flatbuffer_to_json_graph_convert.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
Expand Down Expand Up @@ -649,6 +651,7 @@ absl::Status AddTensorTags(const OperatorT& op, absl::string_view op_label,
}

void AddQuantizationParameters(const std::unique_ptr<TensorT>& tensor,
const size_t size_limit,
const EdgeType edge_type, const int rel_idx,
GraphNodeBuilder& builder) {
if (tensor->quantization == nullptr) return;
Expand All @@ -662,9 +665,13 @@ void AddQuantizationParameters(const std::unique_ptr<TensorT>& tensor,
}
if (quant->scale.empty()) return;

const unsigned num_params = (size_limit < 0)
? quant->scale.size()
: std::min(quant->scale.size(), size_limit);
if (num_params == 0) return;
std::vector<std::string> parameters;
parameters.reserve(quant->scale.size());
for (int i = 0; i < quant->scale.size(); ++i) {
parameters.reserve(num_params);
for (int i = 0; i < num_params; ++i) {
// Parameters will be shown as "[scale] * (q + [zero_point])"
parameters.push_back(
absl::StrCat(quant->scale[i], " * (q + ", quant->zero_point[i], ")"));
Expand All @@ -680,7 +687,7 @@ absl::Status AddNode(
const Buffers& buffers, const std::vector<std::string>& func_names,
const std::optional<const SignatureNameMap>& signature_name_map,
const OpdefsMap& op_defs, const std::unique_ptr<FlatBufferModel>& model_ptr,
const int const_element_count_limit, std::vector<std::string>& node_ids,
const VisualizeConfig& config, std::vector<std::string>& node_ids,
EdgeMap& edge_map, mlir::Builder mlir_builder, Subgraph& subgraph) {
if (op.opcode_index >= op_names.size()) {
return absl::InvalidArgumentError(
Expand Down Expand Up @@ -713,14 +720,16 @@ absl::Status AddNode(
// when the input tensor is constant and not an output of a node. Thus we
// create an auxiliary constant node to align with graph structure.
if (EdgeInfoIncomplete(edge_map.at(tensor_index))) {
RETURN_IF_ERROR(AddAuxiliaryNode(
NodeType::kConstNode, std::vector<int>{tensor_index}, tensors,
buffers, signature_name_map, model_ptr, const_element_count_limit,
node_ids, edge_map, mlir_builder, subgraph));
RETURN_IF_ERROR(
AddAuxiliaryNode(NodeType::kConstNode, std::vector<int>{tensor_index},
tensors, buffers, signature_name_map, model_ptr,
config.const_element_count_limit, node_ids, edge_map,
mlir_builder, subgraph));
}
AppendIncomingEdge(edge_map.at(tensor_index), builder);
AddQuantizationParameters(tensors[tensor_index], EdgeType::kInput, i,
builder);
AddQuantizationParameters(tensors[tensor_index],
config.quant_params_count_limit, EdgeType::kInput,
i, builder);
}

for (int i = 0; i < op.outputs.size(); ++i) {
Expand All @@ -732,8 +741,9 @@ absl::Status AddNode(
.source_node_output_id = absl::StrCat(i)},
edge_map);

AddQuantizationParameters(tensors[tensor_index], EdgeType::kOutput, i,
builder);
AddQuantizationParameters(tensors[tensor_index],
config.quant_params_count_limit,
EdgeType::kOutput, i, builder);
}

status = AddTensorTags(op, node_label, op_defs, builder);
Expand Down Expand Up @@ -802,8 +812,8 @@ absl::Status AddSubgraph(
const Tensors& tensors = subgraph_t.tensors;
RETURN_IF_ERROR(AddNode(i, *op, op_codes, op_names, tensors, buffers,
func_names, signature_name_map, op_defs, model_ptr,
config.const_element_count_limit, node_ids,
edge_map, mlir_builder, subgraph));
config, node_ids, edge_map, mlir_builder,
subgraph));
}

// Adds GraphOutputs node to the subgraph.
Expand Down
9 changes: 9 additions & 0 deletions src/builtin-adapter/models_to_json_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ constexpr char kInputFileFlag[] = "i";
constexpr char kOutputFileFlag[] = "o";
constexpr char kConstElementCountLimitFlag[] = "const_element_count_limit";
constexpr char kDisableMlirFlag[] = "disable_mlir";
constexpr char kQuantParamsCountLimitFlag[] = "quant_params_count_limit";

namespace {

Expand All @@ -42,6 +43,7 @@ int main(int argc, char* argv[]) {
// Creates and parses flags.
std::string input_file, output_file;
int const_element_count_limit = 16;
int quant_params_count_limit = 16;
bool disable_mlir = false;

std::vector<mlir::Flag> flag_list = {
Expand All @@ -61,6 +63,12 @@ int main(int argc, char* argv[]) {
"Disable the MLIR-based conversion. If set to true, the conversion "
"becomes from model directly to graph json",
mlir::Flag::kOptional),
mlir::Flag::CreateFlag(
kQuantParamsCountLimitFlag, &quant_params_count_limit,
"The maximum number of quant parameters. If the number exceeds this "
"threshold, the rest of data will be elided. If the flag is not set, "
"the default threshold is 16 (use -1 to print all)",
mlir::Flag::kOptional),
};
mlir::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);

Expand All @@ -76,6 +84,7 @@ int main(int argc, char* argv[]) {
// Creates visualization config.
tooling::visualization_client::VisualizeConfig config(
const_element_count_limit);
config.quant_params_count_limit = quant_params_count_limit;

const absl::StatusOr<std::string> json_output =
ConvertModelToJson(config, input_file, disable_mlir);
Expand Down
5 changes: 5 additions & 0 deletions src/builtin-adapter/visualize_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ struct VisualizeConfig {
// exceeds this threshold, the rest of data will be elided. The default
// threshold is set to 16 (use -1 to print all).
int const_element_count_limit = 16;

// The maximum number of quantization parameters to be displayed. If the
// number exceeds this threshold, the rest of data will be elided. The default
// threshold is set to 16 (use -1 to print all).
int quant_params_count_limit = 16;
};

} // namespace visualization_client
Expand Down