-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d3eb9cb
commit 41c82e4
Showing
8 changed files
with
317 additions
and
125 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file src/relay/op/tensor/transform.h | ||
* \brief Tranform op attributes that can be shared among Relay and its dialects. | ||
*/ | ||
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ | ||
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ | ||
|
||
#include <algorithm> | ||
#include <limits> | ||
#include <string> | ||
#include <unordered_set> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
template <typename AttrType> | ||
bool ConcatenateRel(const Array<Type>& types, | ||
int num_inputs, | ||
const Attrs& attrs, | ||
const TypeReporter& reporter) { | ||
// types: [data, result] | ||
CHECK_EQ(types.size(), 2); | ||
/* If we receive a tuple we can continue, if we receive | ||
* anything but an incomplete type we should signal an | ||
* error. | ||
*/ | ||
const auto* tensor_tuple = types[0].as<TupleTypeNode>(); | ||
if (tensor_tuple == nullptr) { | ||
throw relay::Error( | ||
RELAY_ERROR( | ||
"concatenate requires a tuple of tensors as the first argument, found " | ||
<< PrettyPrint(types[0]))); | ||
} else if (types[0].as<IncompleteTypeNode>() != nullptr) { | ||
return false; | ||
} | ||
|
||
const auto* param = attrs.as<AttrType>(); | ||
CHECK(param != nullptr); | ||
if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) { | ||
return false; | ||
} | ||
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); | ||
// Sanity check: ndim and dtype. | ||
const int ndim = static_cast<int>(first->shape.size()); | ||
const DataType dtype = first->dtype; | ||
|
||
for (const Type& ele : tensor_tuple->fields) { | ||
if (ele.as<IncompleteTypeNode>()) { | ||
return false; | ||
} | ||
|
||
const auto& e = Downcast<TensorType>(ele); | ||
|
||
int e_ndim = static_cast<int>(e->shape.size()); | ||
const DataType& e_dtype = e->dtype; | ||
if (e_ndim != ndim) { | ||
throw relay::Error("relay.concatenate requires all tensors have the same ndim"); | ||
} | ||
if (e_dtype != dtype) { | ||
throw relay::Error("relay.concatenate requires all tensors have the same dtype"); | ||
} | ||
} | ||
// Sanity check: axis | ||
int axis = param->axis; | ||
if (!(-ndim <= axis && axis < ndim)) { | ||
throw relay::Error(RELAY_ERROR( | ||
"concatenate only accepts `axis` in [-ndim, ndim)" << | ||
", but got axis = " << axis << | ||
", and ndim = " << ndim)); | ||
} | ||
axis = axis < 0 ? ndim + axis : axis; | ||
// Calculate shape | ||
std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end()); | ||
IndexExpr &concat_dim = oshape[axis]; | ||
bool has_any = false; | ||
if (concat_dim.as<Any>()) { | ||
has_any = true; | ||
} else { | ||
for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) { | ||
const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]); | ||
if (e->shape[axis].as<Any>()) { | ||
has_any = true; | ||
break; | ||
} | ||
concat_dim += e->shape[axis]; | ||
} | ||
} | ||
|
||
if (has_any) { | ||
concat_dim = Any::make(); | ||
} | ||
|
||
auto rtype = TensorTypeNode::make(oshape, dtype); | ||
reporter->Assign(types[1], rtype); | ||
return true; | ||
} | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_OP_TENSOR_TRANSFORM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.