forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RELAY][DYN] Dynamic broadcast_to, zeros, ones (apache#6007)
* Dynamic BroadcastTo * fixed lint! * add test_one_hot() back * add one_hot registration back * Dynamic BroadcastTo * fixed lint! * add one_hot registration back * fixed lint.. again * fixed lint * lint * responding to comments * skipping cuda in dynamic test * skipping cuda in dynamic test * fixed i386 test and GPU test * lint * starting ones and zeros * fixed dynamic ones and zeros, wrote dyn ones and zeros test * added static version of zeros, ones and added a check for size of types to static BroadCastToRel * added dynamic to static pass for zeros and ones, dynamic test and dynamic to static test * removed op_str in dyn to static pass test * fixed lint * fix lint hopefully * removed import const * removed import that was actually used * copy all attributes from broadcast_to, ones, zeros, full * responding to comments * fixed build error * finishing rebase * fix lint Co-authored-by: Lily Orth-Smith <[email protected]>
- Loading branch information
Showing
16 changed files
with
384 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ | |
|
||
from . import _algorithm | ||
from . import _transform | ||
from . import _tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
#pylint: disable=invalid-name, unused-argument, len-as-condition | ||
"""Backend compiler related feature registration for dynamic ops""" | ||
|
||
import topi | ||
|
||
from ..op import register_shape_func, register_compute | ||
from ..op import register_broadcast_schedule | ||
from ..op import register_pattern, OpPattern | ||
from .._tensor import full_shape_func, no_data_full_shape_func | ||
|
||
# ones | ||
@register_compute("dyn.ones") | ||
def ones_compute(attrs, inputs, output_type): | ||
assert len(inputs) == 1 | ||
return [topi.full(output_type.shape, output_type.dtype, 1.0)] | ||
|
||
register_broadcast_schedule("dyn.ones") | ||
register_pattern("dyn.ones", OpPattern.ELEMWISE) | ||
|
||
@register_compute("dyn.zeros") | ||
def zeros_compute(attrs, inputs, output_type): | ||
assert len(inputs) == 1 | ||
return [topi.full(output_type.shape, output_type.dtype, 0.0)] | ||
|
||
register_broadcast_schedule("dyn.zeros") | ||
register_pattern("dyn.zeros", OpPattern.ELEMWISE) | ||
|
||
register_shape_func("dyn.broadcast_to", True, full_shape_func) | ||
register_shape_func("dyn.ones", True, no_data_full_shape_func) | ||
register_shape_func("dyn.zeros", True, no_data_full_shape_func) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.