Skip to content

Commit

Permalink
[LPT] Interpolate transformation enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
vzinovie committed Oct 5, 2020
1 parent 848c839 commit d346efb
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,28 @@ const std::vector<LayerTestsUtils::LayerTransformation::LptVersion> versions = {
LayerTestsUtils::LayerTransformation::LptVersion::nGraph
};

const std::vector<ngraph::op::InterpolateAttrs> interpAttrs = {
// {
// ngraph::AxisSet{2, 3},
// "nearest",
// false,
// false,
// {0},
// {0}
// },
// {
// ngraph::AxisSet{2, 3},
// "nearest",
// false,
// true,
// {0},
// {0}
// },
const std::vector<interpAttributes> interpAttrs = {
interpAttributes(
ngraph::AxisSet{2, 3},
"nearest",
false,
false,
{0},
{0}),
interpAttributes(
ngraph::AxisSet{2, 3},
"nearest",
false,
true,
{0},
{0}),
interpAttributes(
ngraph::AxisSet{2, 3},
"linear",
false,
false,
{0},
{0}),
};

const auto combineValues = ::testing::Combine(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,35 @@

namespace LayerTestsDefinitions {

class interpAttributes {
public:
ngraph::AxisSet axes;
std::string mode;
bool align_corners;
bool antialias;
std::vector<size_t> pads_begin;
std::vector<size_t> pads_end;

bool shouldBeTransformed;

interpAttributes() = default;

interpAttributes(const ngraph::AxisSet& axes,
const std::string& mode,
const bool& align_corners,
const bool& antialias,
const std::vector<size_t>& pads_begin,
const std::vector<size_t>& pads_end,
const bool& shouldBeTransformed = true) :
axes(axes), mode(mode), align_corners(align_corners),
antialias(antialias), pads_begin(pads_begin), pads_end(pads_end) {}
};

typedef std::tuple<
ngraph::element::Type,
std::pair<ngraph::Shape, ngraph::Shape>,
std::string,
ngraph::op::InterpolateAttrs,
interpAttributes,
LayerTestsUtils::LayerTransformation::LptVersion> InterpolateTransformationParams;

class InterpolateTransformation :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,54 @@ namespace LayerTestsDefinitions {

template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
os << "{ ";
os << "{";
for (size_t i = 0; i < values.size(); ++i) {
os << values[i];
if (i != (values.size() - 1ul)) {
os << ", ";
os << ",";
}
}
os << " }";
os << "}";
return os;
}

std::string InterpolateTransformation::getTestCaseName(testing::TestParamInfo<InterpolateTransformationParams> obj) {
ngraph::element::Type precision;
std::pair<ngraph::Shape, ngraph::Shape> shapes;
std::string targetDevice;
ngraph::op::InterpolateAttrs interpAttrs;
interpAttributes attributes;
auto params = LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8();
LayerTestsUtils::LayerTransformation::LptVersion version;
std::tie(precision, shapes, targetDevice, interpAttrs, version) = obj.param;
std::tie(precision, shapes, targetDevice, attributes, version) = obj.param;

std::ostringstream result;
result << getTestCaseNameByParams(precision, shapes.first, targetDevice, params, version) <<
"_" << shapes.second << "_" <<
interpAttrs.align_corners <<
interpAttrs.antialias <<
interpAttrs.axes <<
interpAttrs.mode <<
interpAttrs.pads_begin <<
interpAttrs.pads_end;
result << getTestCaseNameByParams(precision, shapes.first, targetDevice, params, version) << "_" <<
shapes.second << "_" <<
attributes.align_corners << "_" <<
attributes.antialias << "_" <<
attributes.axes << "_" <<
attributes.mode << "_" <<
attributes.pads_begin << "_" <<
attributes.pads_end;
return result.str();
}

void InterpolateTransformation::SetUp() {
SetRefMode(LayerTestsUtils::RefMode::IE);
ngraph::element::Type precision;
std::pair<ngraph::Shape, ngraph::Shape> shapes;
ngraph::op::InterpolateAttrs interpAttrs;
interpAttributes attributes;
auto params = LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8();
LayerTestsUtils::LayerTransformation::LptVersion version;
std::tie(precision, shapes, targetDevice, interpAttrs, version) = this->GetParam();
std::tie(precision, shapes, targetDevice, attributes, version) = this->GetParam();

ngraph::op::InterpolateAttrs interpAttrs;
interpAttrs.axes = attributes.axes;
interpAttrs.mode = attributes.mode;
interpAttrs.align_corners = attributes.align_corners;
interpAttrs.antialias = attributes.antialias;
interpAttrs.pads_begin = attributes.pads_begin;
interpAttrs.pads_end = attributes.pads_end;

ConfigurePlugin(version);

Expand All @@ -71,7 +79,7 @@ void InterpolateTransformation::validate() {
ngraph::element::Type precision;
std::pair<ngraph::Shape, ngraph::Shape> shapes;
std::string targetDevice;
ngraph::op::InterpolateAttrs interpAttrs;
interpAttributes interpAttrs;
auto params = LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8();
LayerTestsUtils::LayerTransformation::LptVersion version;
std::tie(precision, shapes, targetDevice, interpAttrs, version) = this->GetParam();
Expand All @@ -86,16 +94,16 @@ void InterpolateTransformation::validate() {
std::map<std::string, InferenceEngine::DataPtr>::iterator it = outputs.begin();
const InferenceEngine::CNNLayerPtr outputLayer = getCreatorLayer(it->second).lock();
EXPECT_TRUE(outputLayer != nullptr);
EXPECT_EQ("ScaleShift", outputLayer->type);
EXPECT_EQ(interpAttrs.mode == "linear" ? "Interp" : "ScaleShift", outputLayer->type);

EXPECT_EQ(1ul, outputLayer->insData.size());
const InferenceEngine::DataPtr insData = outputLayer->insData[0].lock();
EXPECT_TRUE(insData != nullptr);
const InferenceEngine::CNNLayerPtr interpolate = getCreatorLayer(insData).lock();
EXPECT_TRUE(interpolate != nullptr);
EXPECT_EQ("Resample", interpolate->type);
EXPECT_EQ(interpAttrs.mode == "linear" ? "ScaleShift" : "Resample", interpolate->type);

if (params.updatePrecisions) {
if (params.updatePrecisions && interpAttrs.mode == "nearest") {
const InferenceEngine::Precision precision = interpolate->outData[0]->getTensorDesc().getPrecision();
EXPECT_TRUE((precision == InferenceEngine::Precision::U8) || (precision == InferenceEngine::Precision::I8));
}
Expand Down

0 comments on commit d346efb

Please sign in to comment.