Skip to content

Commit

Permalink
feat(supportedops): Application to dump a list of supported operators
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Feb 17, 2021
1 parent 1c9dfe2 commit 872d9a3
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 0 deletions.
12 changes: 12 additions & 0 deletions core/conversion/converters/NodeConverterRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class NodeConverterRegistry {
public:
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
registered_converter_schemas_.insert(c10::toString(*signature));
auto name = signature->operator_name();
auto iter = converter_lut_.find(name);
if (iter != converter_lut_.end()) {
Expand Down Expand Up @@ -83,8 +84,15 @@ class NodeConverterRegistry {
}
}

std::vector<std::string> GetRegisteredConverterList() {
std::vector<std::string> converter_list;
std::copy(registered_converter_schemas_.begin(), registered_converter_schemas_.end(), std::back_inserter(converter_list));
return converter_list;
}

private:
ConverterLUT converter_lut_;
std::set<std::string> registered_converter_schemas_;
};

NodeConverterRegistry& get_converter_registry() {
Expand Down Expand Up @@ -115,6 +123,10 @@ bool node_is_convertable(const torch::jit::Node* n) {
return get_converter_registry().Convertable(n);
}

std::vector<std::string> get_converter_list() {
return get_converter_registry().GetRegisteredConverterList();
}

RegisterNodeConversionPatterns&& RegisterNodeConversionPatterns::pattern(ConversionPattern p) && {
register_node_converter(std::move(p));
return std::move(*this);
Expand Down
1 change: 1 addition & 0 deletions core/conversion/converters/converters.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class RegisterNodeConversionPatterns {

bool node_is_convertable(const torch::jit::Node* n);
OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature);
std::vector<std::string> get_converter_list();

} // namespace converters
} // namespace conversion
Expand Down
14 changes: 14 additions & 0 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class NodeEvaluatorRegistry {
"Attempting to override already registered evaluator " << node_kind.toQualString()
<< ", merge implementations instead");
}
for (auto const& e : eval_reg.options.supported_variants) {
registered_evaluator_schemas_.insert(e);
}
evaluator_lut_[node_kind] = std::move(eval_reg);
}

Expand Down Expand Up @@ -76,6 +79,12 @@ class NodeEvaluatorRegistry {
return evaluator;
}

std::vector<std::string> GetRegisteredEvaluatorList() {
std::vector<std::string> evaluator_list;
std::copy(registered_evaluator_schemas_.begin(), registered_evaluator_schemas_.end(), std::back_inserter(evaluator_list));
return evaluator_list;
}

bool EvalAtConversionTime(const torch::jit::Node* n) {
auto evaluator = FindEvaluator(n);
if (evaluator == nullptr) {
Expand All @@ -87,6 +96,7 @@ class NodeEvaluatorRegistry {

private:
EvaluatorLUT evaluator_lut_;
std::set<std::string> registered_evaluator_schemas_;
};

NodeEvaluatorRegistry& get_evaluator_registry() {
Expand All @@ -99,6 +109,10 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
return get_evaluator_registry().EvalAtConversionTime(n);
}

std::vector<std::string> getEvaluatorList() {
return get_evaluator_registry().GetRegisteredEvaluatorList();
}

c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
auto evaluator = get_evaluator_registry().GetEvaluator(n);
return evaluator(n, args);
Expand Down
3 changes: 3 additions & 0 deletions core/conversion/evaluators/evaluators.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*,
struct EvalOptions {
std::set<c10::TypePtr> blacklisted_output_types;
std::vector<c10::OperatorName> valid_schemas;
std::vector<std::string> supported_variants;
EvalOptions() = default;
EvalOptions& blacklistOutputTypes(std::set<c10::TypePtr> types) {
use_options = true;
blacklisted_output_types = types;
return *this;
}
EvalOptions& validSchemas(std::set<std::string> schemas) {
std::copy(schemas.begin(), schemas.end(), std::back_inserter(supported_variants));
use_options = true;
for (auto s : schemas) {
valid_schemas.push_back(torch::jit::parseSchema(s).operator_name());
Expand Down Expand Up @@ -72,6 +74,7 @@ struct EvalRegistration {

c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
std::vector<std::string> getEvaluatorList();
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
void register_node_evaluator(EvalRegistration r);

Expand Down
12 changes: 12 additions & 0 deletions cpp/supportedops/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package(default_visibility = ["//visibility:public"])

cc_binary(
name = "supportedops",
srcs = [
"main.cpp"
],
deps = [
"//cpp/api:trtorch",
"//core/conversion/converters"
],
)
48 changes: 48 additions & 0 deletions cpp/supportedops/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "core/conversion/converters/converters.h"
#include "core/conversion/evaluators/evaluators.h"

#include <string>
#include <sstream>
#include <vector>
#include <iostream>

int main(int argc, const char* argv[]) {
std::vector<std::string> converters = trtorch::core::conversion::converters::get_converter_list();
std::vector<std::string> evaluators = trtorch::core::conversion::evaluators::getEvaluatorList();

std::stringstream ss;

ss << R"TITLE(
=================================
Operators Supported
=================================
)TITLE";

ss << R"SEC(
Operators Currently Supported Through Converters
-------------------------------------------------
)SEC";

for (auto c : converters) {
ss << "- " << c << std::endl;
}

ss << R"SEC(
Operators Currently Supported Through Evaluators
-------------------------------------------------
)SEC";

for (auto e : evaluators) {
ss << "- " << e << std::endl;
}

std::ofstream ofs;
ofs.open(argv[1]);

ofs << ss.rdbuf();

return 0;
}

0 comments on commit 872d9a3

Please sign in to comment.