Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LPT refactoring #17

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API AddTransformation : public EltwiseBaseTransformation {
public:
AddTransformation(const Params& params = Params());
~AddTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API AvgPoolTransformation : public LayerTransformation {
public:
AvgPoolTransformation();
AvgPoolTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ namespace low_precision {
class TRANSFORMATIONS_API ClampTransformation : public LayerTransformation {
public:
ClampTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class OperationPrecisionRestriction {
using PrecisionsByPort = std::vector<std::pair<size_t, std::set<ngraph::element::Type>>>;

std::string name;
uint64_t version;
int64_t version;
std::vector<std::pair<size_t, std::set<ngraph::element::Type>>> precisionsByPort;

OperationPrecisionRestriction() = default;
Expand All @@ -45,7 +45,7 @@ class OperationPrecisionRestriction {
const PrecisionsByPort& precisionsByPort,
const bool specifiedVersion = false) {
const ngraph::Node::type_info_t& typeInfo = T::get_type_info_static();
return OperationPrecisionRestriction(typeInfo.name, specifiedVersion ? typeInfo.version : -1ull, precisionsByPort);
return OperationPrecisionRestriction(typeInfo.name, specifiedVersion ? typeInfo.version : -1ll, precisionsByPort);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class TRANSFORMATIONS_API ConcatTransformation : public LayerTransformation {
public:
ConcatTransformation(const Params& params = Params());
~ConcatTransformation() override {};
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class TRANSFORMATIONS_API ConcatMultiChannelsTransformation : public ConcatTrans
public:
ConcatMultiChannelsTransformation(const Params& params);
~ConcatMultiChannelsTransformation() override {};
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API ConvertTransformation : public LayerTransformation {
public:
ConvertTransformation(const Params& params);
~ConvertTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace low_precision {
class TRANSFORMATIONS_API ConvolutionTransformation : public WeightableLayerTransformation {
public:
ConvolutionTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isQuantized(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API DepthToSpaceTransformation : public TransparentBaseTra
DepthToSpaceTransformation(const Params& params = Params());
~DepthToSpaceTransformation() override {}
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace low_precision {
class TRANSFORMATIONS_API FakeQuantizeTransformation : public LayerTransformation {
public:
FakeQuantizeTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class TRANSFORMATIONS_API FakeQuantizeDecompositionTransformation : public Layer
public:
FakeQuantizeDecompositionTransformation(const Params& params, TransformationContext& context);
FakeQuantizeDecompositionTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TRANSFORMATIONS_API FoldConvertTransformation : public LayerTransformation
public:
FoldConvertTransformation(const Params& params) : LayerTransformation(params) {}
~FoldConvertTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
//void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API FuseConvertTransformation : public LayerTransformation
public:
FuseConvertTransformation(const Params& params);
~FuseConvertTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API FuseFakeQuantizeTransformation : public LayerTransform
public:
FuseFakeQuantizeTransformation(const Params& params);
~FuseFakeQuantizeTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API FuseMultiplyToFakeQuantizeTransformation : public Laye
public:
FuseMultiplyToFakeQuantizeTransformation(const Params& params = Params());
~FuseMultiplyToFakeQuantizeTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API FuseSubtractToFakeQuantizeTransformation : public Laye
public:
FuseSubtractToFakeQuantizeTransformation(const Params& params = Params());
~FuseSubtractToFakeQuantizeTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace low_precision {
class TRANSFORMATIONS_API GroupConvolutionTransformation : public ConvolutionTransformation {
public:
GroupConvolutionTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isQuantized(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API InterpolateTransformation : public LayerTransformation
InterpolateTransformation(const Params& params = Params());
~InterpolateTransformation() override {}
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ class TRANSFORMATIONS_API LayerTransformation : public ngraph::pass::MatcherPass

LayerTransformation(const Params& params);
virtual ~LayerTransformation() = default;
virtual void registerMatcherIn(ngraph::pass::GraphRewrite& pass, TransformationContext& context) const = 0;
virtual bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const = 0;

void setParams(const Params& params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API MatMulTransformation : public LayerTransformation {
MatMulTransformation(const Params& params = Params());
~MatMulTransformation() override {}
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ namespace low_precision {
class TRANSFORMATIONS_API MaxPoolTransformation : public LayerTransformation {
public:
MaxPoolTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API MultiplyTransformation : public EltwiseBaseTransformat
public:
MultiplyTransformation(const Params& params = Params());
~MultiplyTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public Laye
public:
MultiplyToGroupConvolutionTransformation(const Params& params = Params());
~MultiplyToGroupConvolutionTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ namespace low_precision {
class TRANSFORMATIONS_API MVNTransformation : public LayerTransformation {
public:
MVNTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ std::shared_ptr<ngraph::VariantWrapper<T>> getAttribute(const std::shared_ptr<No
auto attribute = std::dynamic_pointer_cast<ngraph::VariantWrapper<T>>(it->second);
assert(attribute != nullptr);
return attribute;
};
}

} // namespace low_precision
} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ namespace low_precision {
class TRANSFORMATIONS_API NormalizeL2Transformation : public LayerTransformation {
public:
NormalizeL2Transformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API PReluTransformation : public LayerTransformation {
public:
PReluTransformation(const Params& params = Params());
~PReluTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TRANSFORMATIONS_API ReluTransformation : public LayerTransformation {
public:
ReluTransformation(const Params& params = Params());
~ReluTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TRANSFORMATIONS_API ReshapeTransformation : public LayerTransformation {
public:
ReshapeTransformation(const Params& params = Params());
~ReshapeTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TRANSFORMATIONS_API ngraph::VariantWrapper<std::shared_ptr<IntervalsAlignm

std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes) override;

std::shared_ptr<IntervalsAlignmentAttribute> get() const { return this->m_value; };
std::shared_ptr<IntervalsAlignmentAttribute> get() const { return this->m_value; }

std::string get_string() override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ class PrecisionPreservedAttribute {
public:
class SharedValue {
public:
SharedValue(const bool value) : value(value) {}
SharedValue(const bool value, const std::string& operationName) : value(value), operationName(operationName) {}
std::string operationName;
bool value;

SharedValue(const bool value) : value(value) {}
SharedValue(const bool value, const std::string& operationName) : value(value), operationName(operationName) {}
};

PrecisionPreservedAttribute(const bool value, const std::string& operationName) : sharedValue(std::make_shared<SharedValue>(value, operationName)) {}
Expand Down Expand Up @@ -54,7 +55,7 @@ class TRANSFORMATIONS_API ngraph::VariantWrapper<PrecisionPreservedAttribute> :

std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes) override;

PrecisionPreservedAttribute get() { return this->m_value; };
PrecisionPreservedAttribute get() { return this->m_value; }

std::string get_string() override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TRANSFORMATIONS_API ngraph::VariantWrapper<std::shared_ptr<PrecisionsAttri

std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node) override;

std::shared_ptr<PrecisionsAttribute> get() { return this->m_value; };
std::shared_ptr<PrecisionsAttribute> get() { return this->m_value; }

virtual std::string get_string();
};
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TRANSFORMATIONS_API ngraph::VariantWrapper<std::shared_ptr<QuantizationAli

std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node) override;

std::shared_ptr<QuantizationAlignmentAttribute> get() { return this->m_value; };
std::shared_ptr<QuantizationAlignmentAttribute> get() { return this->m_value; }

std::string get_string() override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace low_precision {
class TRANSFORMATIONS_API SplitTransformation : public LayerTransformation {
public:
SplitTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace low_precision {
class TRANSFORMATIONS_API SqueezeTransformation : public LayerTransformation {
public:
SqueezeTransformation(const Params& params = Params());
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace low_precision {
class TRANSFORMATIONS_API StridedSliceTransformation : public LayerTransformation {
public:
StridedSliceTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
//void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Loading