Skip to content

Commit

Permalink
[LPT] FQ decomposition in common pass manager - exporation #3
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Jun 24, 2021
1 parent 557d46c commit a29fdb7
Show file tree
Hide file tree
Showing 76 changed files with 88 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API AddTransformation : public EltwiseBaseTransformatio
public:
NGRAPH_RTTI_DECLARATION;
AddTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API AvgPoolTransformation : public LayerTransformation
public:
NGRAPH_RTTI_DECLARATION;
AvgPoolTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API ClampTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
ClampTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class LP_TRANSFORMATIONS_API ConcatTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
ConcatTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API ConvertTransformation : public LayerTransformation
public:
NGRAPH_RTTI_DECLARATION;
ConvertTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API ConvolutionTransformation : public WeightableLayerT
public:
NGRAPH_RTTI_DECLARATION;
ConvolutionTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isQuantized(const std::shared_ptr<const Node>& layer) const noexcept override;
static bool isQuantizedStatic(const std::shared_ptr<const Node>& layer) noexcept;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API ConvolutionBackpropDataTransformation : public Weig
public:
ConvolutionBackpropDataTransformation(const Params& params = Params());
//void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isQuantized(const std::shared_ptr<const Node>& layer) const noexcept override;
static bool isQuantizedStatic(const std::shared_ptr<const Node>& layer) noexcept;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LP_TRANSFORMATIONS_API DepthToSpaceTransformation : public TransparentBase
public:
NGRAPH_RTTI_DECLARATION;
DepthToSpaceTransformation(const Params& params = Params());
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ class LP_TRANSFORMATIONS_API FakeQuantizeTransformation : public LayerTransforma
public:
NGRAPH_RTTI_DECLARATION;
FakeQuantizeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

static bool checkElementwise(const std::shared_ptr<Node>& eltwise);

private:
std::shared_ptr<opset1::FakeQuantize> fuseElementwise(TransformationContext& context, const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const;
std::shared_ptr<opset1::FakeQuantize> fuseElementwise(
TransformationContext& context,
MatcherPass& matcherPass,
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const;
};

} // namespace low_precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LP_TRANSFORMATIONS_API FakeQuantizeDecompositionTransformation : public La
public:
NGRAPH_RTTI_DECLARATION;
FakeQuantizeDecompositionTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API FoldConvertTransformation : public LayerTransformat
public:
NGRAPH_RTTI_DECLARATION;
FoldConvertTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API FoldFakeQuantizeTransformation : public LayerTransf
public:
NGRAPH_RTTI_DECLARATION;
FoldFakeQuantizeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API FuseConvertTransformation : public LayerTransformat
public:
NGRAPH_RTTI_DECLARATION;
FuseConvertTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API FuseFakeQuantizeTransformation : public LayerTransf
public:
NGRAPH_RTTI_DECLARATION;
FuseFakeQuantizeTransformation(const Params& params);
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API FuseMultiplyToFakeQuantizeTransformation : public L
public:
NGRAPH_RTTI_DECLARATION;
FuseMultiplyToFakeQuantizeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API FuseSubtractToFakeQuantizeTransformation : public L
public:
NGRAPH_RTTI_DECLARATION;
FuseSubtractToFakeQuantizeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API GroupConvolutionTransformation : public Convolution
public:
NGRAPH_RTTI_DECLARATION;
GroupConvolutionTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isQuantized(const std::shared_ptr<const Node>& layer) const noexcept override;
static bool isQuantizedStatic(const std::shared_ptr<const Node>& layer) noexcept;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LP_TRANSFORMATIONS_API InterpolateTransformation : public LayerTransformat
public:
NGRAPH_RTTI_DECLARATION;
InterpolateTransformation(const Params& params = Params());
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class LP_TRANSFORMATIONS_API LayerTransformation : public ngraph::pass::MatcherP

LayerTransformation(const Params& params);
virtual ~LayerTransformation() = default;
virtual bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const = 0;
virtual bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) = 0;

void setContext(TransformationContext* context) noexcept;

Expand Down Expand Up @@ -265,7 +265,7 @@ class LP_TRANSFORMATIONS_API LayerTransformation : public ngraph::pass::MatcherP
std::shared_ptr<ngraph::Node> lastNode,
std::string originalName) const;

void addPattern(ngraph::pass::GraphRewrite& pass, TransformationContext& context, std::shared_ptr<Node> patternRoot) const;
void addPattern(ngraph::pass::GraphRewrite& pass, TransformationContext& context, std::shared_ptr<Node> patternRoot);

//TODO: replace with canBeTransformed when quantization by special dimension is supported for all transformations
bool canBeTransformedSpatialDimension(const TransformationContext& context, std::shared_ptr<Node> layer) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API MatMulTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
MatMulTransformation(const Params& params = Params());
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
static bool is3DTensorOnActivations(const std::shared_ptr<const Node>& node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LP_TRANSFORMATIONS_API MaxPoolTransformation : public LayerTransformation
NGRAPH_RTTI_DECLARATION;
MaxPoolTransformation(const Params& params = Params());
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API MultiplyTransformation : public EltwiseBaseTransfor
public:
NGRAPH_RTTI_DECLARATION;
MultiplyTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
};

} // namespace low_precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public L
const Params& params = Params(),
const OperationPrecisionRestriction::PrecisionsByPort& restrictions = {});
~MultiplyToGroupConvolutionTransformation() override {}
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool isQuantized(const std::shared_ptr<const Node>& layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LP_TRANSFORMATIONS_API MVNTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
MVNTransformation(const Params& params = Params());
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LP_TRANSFORMATIONS_API NormalizeL2Transformation : public LayerTransformat
public:
NGRAPH_RTTI_DECLARATION;
NormalizeL2Transformation(const Params& params = Params());
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext &context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API PReluTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
PReluTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace low_precision {
class LP_TRANSFORMATIONS_API ReduceBaseTransformation : public LayerTransformation {
public:
ReduceBaseTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> reduce) const override;

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API ReluTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
ReluTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API ReshapeTransformation : public LayerTransformation
public:
NGRAPH_RTTI_DECLARATION;
ReshapeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API ShuffleChannelsTransformation : public LayerTransfo
public:
NGRAPH_RTTI_DECLARATION;
ShuffleChannelsTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LP_TRANSFORMATIONS_API SplitTransformation : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
SplitTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
void updateOutputs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LP_TRANSFORMATIONS_API SqueezeTransformation : public LayerTransformation
public:
NGRAPH_RTTI_DECLARATION;
SqueezeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LP_TRANSFORMATIONS_API StridedSliceTransformation : public LayerTransforma
public:
NGRAPH_RTTI_DECLARATION;
StridedSliceTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Loading

0 comments on commit a29fdb7

Please sign in to comment.