-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay][TOPI] Gluncv SSD support on the GPU (#2784)
* ssd gluoncv gpu op updated * ssd gluoncv gpu op updated * tutorials and testes modified * tutorials and testes modified * fix lint * fix lint * address comment * multibox bug fixed * space line added * use less threads per block * use less threads per block * less threads per block for get valid count * less threads per block for get valid count * merge with master * Revert "less threads per block for get valid count" This reverts commit 08896cf. * Revert "less threads per block for get valid count" This reverts commit 08896cf. * typo fixed * elem length made to a variable * fix lint error * fix lint error * lint fixed * bug fixed * bug fixed * lint fixed * error fixed * error fixed * test ci * test ci * seperate argsort to be an independent op * seperate argsort to be an independent op * fix lint * fix lint * remove unsupported models * typo fixed * argsort added to realy * solve conflicts with master * fix lint * fix lint * test push * Revert "test push" This reverts commit 6db0088. * fix lint error * fix more lint * cpu test_sort udpated * debug ci * nms fixed * expose argsort to relay frontend * test ci * fix lint * sort register error fixed * fix nnvm * nms type fixed * adaptive pooling added to relay * Revert "adaptive pooling added to relay" This reverts commit 1119f1f. * fix lint * expose argsort op * fix lint * fix lint * fix lint * sort test updated * sort bug fixed * nnvm error fixed * fix argsort default data type returned to be float insteaf of int * fix lint * fix lint * test fixed * fix valid count * fix titanx bug * tutorial add both targets * titanx error fixed * try to fix CI old gpu error * try to solve CI GPU error * get_valid_count added * reverse get_valid_count * get valid count optimized * address comments * fix ci error * remove unessesary block sync * add back one sync * address comments * address more comments * more comments * move sort to be indepent algorithm * typo fixed * more typos * comments addressed * doc updated * fix pylint * address final comments * apache license added
- Loading branch information
Showing
34 changed files
with
1,731 additions
and
372 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relay/attrs/vision.h | ||
* \brief Auxiliary attributes for vision operators. | ||
*/ | ||
#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_ | ||
#define TVM_RELAY_ATTRS_ALGORITHM_H_ | ||
|
||
#include <tvm/attrs.h> | ||
#include <string> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! \brief Attributes used in argsort operators */ | ||
struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> { | ||
int axis; | ||
bool is_ascend; | ||
DataType dtype; | ||
|
||
TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") { | ||
TVM_ATTR_FIELD(axis).set_default(-1) | ||
.describe("Axis along which to sort the input tensor." | ||
"If not given, the flattened array is used."); | ||
TVM_ATTR_FIELD(is_ascend).set_default(true) | ||
.describe("Whether to sort in ascending or descending order." | ||
"By default, sort in ascending order"); | ||
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>()) | ||
.describe("DType of the output indices."); | ||
} | ||
}; | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"Definition of classic algorithms" | ||
# pylint: disable=invalid-name,unused-argument | ||
from __future__ import absolute_import | ||
|
||
import topi | ||
from topi.util import get_const_int | ||
from ..op import OpPattern, register_compute, register_schedule, register_pattern | ||
|
||
|
||
@register_schedule("argsort") | ||
def schedule_argsort(_, outs, target): | ||
"""Schedule definition of argsort""" | ||
with target: | ||
return topi.generic.schedule_argsort(outs) | ||
|
||
|
||
@register_compute("argsort") | ||
def compute_argsort(attrs, inputs, _, target): | ||
"""Compute definition of argsort""" | ||
axis = get_const_int(attrs.axis) | ||
is_ascend = bool(get_const_int(attrs.is_ascend)) | ||
dtype = str(attrs.dtype) | ||
return [ | ||
topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \ | ||
dtype=dtype, flag=False) | ||
] | ||
|
||
|
||
register_pattern("argsort", OpPattern.OPAQUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Classic algorithm operation""" | ||
from __future__ import absolute_import as _abs | ||
from . import _make | ||
|
||
def argsort(data, axis=-1, is_ascend=1, dtype="float32"): | ||
"""Performs sorting along the given axis and returns an array of indicies | ||
having same shape as an input array that index data in sorted order. | ||
Parameters | ||
---------- | ||
data : relay.Expr | ||
The input data tensor. | ||
valid_count : tvm.Tensor | ||
The number of valid elements to be sorted. | ||
axis : int, optional | ||
Axis long which to sort the input tensor. | ||
is_ascend : boolean, optional | ||
Whether to sort in ascending or descending order. | ||
dtype : string, optional | ||
DType of the output indices. | ||
Returns | ||
------- | ||
out : relay.Expr | ||
Tensor with same shape as data. | ||
""" | ||
return _make.argsort(data, axis, is_ascend, dtype) |
Oops, something went wrong.