Commit 7923f71: Address PR comments

elvin-n committed May 4, 2022
1 parent fb29643 commit 7923f71
Showing 8 changed files with 74 additions and 65 deletions.
python/tvm/topi/adreno/conv2d_alter_op.py (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
-"""Conv2D alter op and legalize functions for x86"""
+"""Conv2D alter op for Qualcomm Adreno GPU"""
 
 import logging
python/tvm/topi/adreno/conv2d_nchw.py (2 changes: 1 addition & 1 deletion)
@@ -182,7 +182,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty
         tag="conv2d_nchwc",
     )
 
-    if not convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning:
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
         dummy_cast = te.compute(
             (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
             lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
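The flipped predicate is the substance of this hunk: judging from the surrounding code, the dummy cast stage accompanies the 5-d to 4-d unpacking that happens only when the layout was converted from 4-d and only outside AutoTVM tuning, and the old guard selected exactly the opposite cases. The same fix recurs in the three files below. A standalone truth-table sketch of the change (plain Python; the flag names mirror the diff):

# Compare the old and the corrected guards over every flag combination.
def old_guard(convert_from4d, in_tuning):
    return not convert_from4d and in_tuning

def new_guard(convert_from4d, in_tuning):
    return convert_from4d and not in_tuning

for convert_from4d in (False, True):
    for in_tuning in (False, True):
        print(convert_from4d, in_tuning,
              old_guard(convert_from4d, in_tuning),
              new_guard(convert_from4d, in_tuning))
# Only convert_from4d=True with in_tuning=False takes the branch under the new guard.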
python/tvm/topi/adreno/conv2d_nhwc.py (2 changes: 1 addition & 1 deletion)
@@ -178,7 +178,7 @@ def compute_conv2d_NHWC_HWIO(Input, Filter, stride, padding, dilation, out_dtype
         tag="conv2d_nhwc",
     )
 
-    if not convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning:
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
         dummy_cast = te.compute(
             (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block),
             lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype),
python/tvm/topi/adreno/depthwise_conv2d_nchw.py (2 changes: 1 addition & 1 deletion)
@@ -187,7 +187,7 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilatio
         tag="depthwise_conv2d_nchwc_kcrsk",
     )
 
-    if not convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning:
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
         dummy_cast = te.compute(
             (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
             lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
python/tvm/topi/adreno/depthwise_conv2d_nhwc.py (2 changes: 1 addition & 1 deletion)
@@ -180,7 +180,7 @@ def compute_depthwise_conv2d_NHWC_HWOI(Input, Filter, stride, padding, dilation,
         tag="depthwise_conv2d_nhwc",
     )
 
-    if not convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning:
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
         dummy_cast = te.compute(
             (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block),
             lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype),
python/tvm/topi/adreno/utils.py (116 changes: 61 additions & 55 deletions)
@@ -25,7 +25,7 @@
 from ..utils import get_const_tuple
 
 
-def getDiv(value, start):
+def get_div(value, start):
     """Returns the largest divisor of `value` not exceeding `start`"""
     div = 1
     for d in range(start, 0, -1):
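The diff view collapses the rest of this helper; a plausible completion, consistent with the docstring and with the get_div(ftc, 128) call sites further down (the early break is an assumption based on the counting-down loop):

def get_div(value, start):
    """Return the largest divisor of `value` that is <= `start` (1 if none)."""
    div = 1
    for d in range(start, 0, -1):
        if value % d == 0:
            div = d  # counting down, so the first exact divisor is the largest
            break
    return div

# Used below as stage.split(fused, factor=get_div(ftc, 128)); an exact divisor
# keeps the block/thread split even, so no tail iteration needs masking.
print(get_div(28224, 128))  # 126, since 28224 = 126 * 224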
@@ -57,22 +57,25 @@ def split_to_chunks(trip_count, block):
     ----------
     out: tuple of the (chunks, block, tail)
     """
-    tail = trip_count % 4
-    chunks = trip_count // 4
+    tail = trip_count % block
+    chunks = trip_count // block
     if tail == 0:
-        tail = 4
+        tail = block
     else:
         chunks += 1
     return chunks, block, tail
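A quick check of the generalized arithmetic, which was previously hardcoded to a block of 4 (the definition is repeated verbatim from the hunk above):

def split_to_chunks(trip_count, block):
    tail = trip_count % block
    chunks = trip_count // block
    if tail == 0:
        tail = block  # the last chunk is completely filled
    else:
        chunks += 1   # one extra, partially filled chunk
    return chunks, block, tail

print(split_to_chunks(32, 4))  # (8, 4, 4): exact fit
print(split_to_chunks(35, 4))  # (9, 4, 3): last chunk holds 3 real channels
print(split_to_chunks(35, 8))  # (5, 8, 3): same tensor with a block of 8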


 def pack_input(
-    Input, layout, batch, in_channel_chunks, in_channel_block, in_channel_tail, in_height, in_width
+    Input, layout, batch, chunks, block, original_tail, in_height, in_width
 ):
     """
     Adds compute stages for packing of the data in runtime. Extends channel dimensions
     to be divisible by a factor of 4.
+    This function should be substituted by Schedule.transform_layout() in the future: see
+    https://github.com/apache/tvm-rfcs/blob/main/rfcs/0039-buffer-physical-layout.md
 
     Parameters
     ----------
     Input: tvm.te.Tensor
@@ -85,18 +88,18 @@ def pack_input(
     batch: int
         Batch size
-    in_channel_chunks: int
+    chunks: int
         Number of channel chunks in the final tensor
-    in_channel_block: int
-        Number of channel blocks been in the final tensor
+    block: int
+        Size of the channel block
-    in_channel_tail: int
+    original_tail: int
         Tail of the last chunk: the difference between the original and the blocked number of channels
-        If in_channel_tail != in_channel_block:
-            original_channels = in_channel_chunks * in_channel_block - in_channel_tail
+        If original_tail != block:
+            original_channels = (chunks - 1) * block + original_tail
         else
-            original_channels = in_channel_chunks * in_channel_block
+            original_channels = chunks * block
     in_height: int
         Height of the feature map
@@ -109,37 +112,37 @@

     def _reorder_data_nchw(*indices):
         condition = []
-        condition.append(indices[1] == in_channel_chunks - 1)
-        condition.append(indices[4] >= in_channel_tail)
+        condition.append(indices[1] == chunks - 1)
+        condition.append(indices[4] >= original_tail)
         condition = tvm.tir.all(*condition)
         return tvm.tir.if_then_else(
             condition,
             pad_value,
-            Input[indices[0], indices[1] * in_channel_block + indices[4], indices[2], indices[3]],
+            Input[indices[0], indices[1] * block + indices[4], indices[2], indices[3]],
         )
 
     def _reorder_data_nhwc(*indices):
         condition = []
-        condition.append(indices[3] == in_channel_chunks - 1)
-        condition.append(indices[4] >= in_channel_tail)
+        condition.append(indices[3] == chunks - 1)
+        condition.append(indices[4] >= original_tail)
         condition = tvm.tir.all(*condition)
         return tvm.tir.if_then_else(
             condition,
             pad_value,
-            Input[indices[0], indices[1], indices[2], indices[3] * in_channel_block + indices[4]],
+            Input[indices[0], indices[1], indices[2], indices[3] * block + indices[4]],
         )
 
     # compute:
     if layout == "NCHW":
         reordered_data = te.compute(
-            [batch, in_channel_chunks, in_height, in_width, in_channel_block],
+            [batch, chunks, in_height, in_width, block],
             _reorder_data_nchw,
             name="input_pack",
             tag="input_pack",
         )
     elif layout == "NHWC":
         reordered_data = te.compute(
-            [batch, in_height, in_width, in_channel_chunks, in_channel_block],
+            [batch, in_height, in_width, chunks, block],
             _reorder_data_nhwc,
             name="input_pack",
             tag="input_pack",
@@ -152,20 +155,23 @@ def _reorder_data_nhwc(*indices):
 def pack_filter(
     Filter,
     layout,
-    out_channel_chunks,
-    out_channel_block,
-    out_channel_tail,
+    out_chunks,
+    out_block,
+    out_original_tail,
     in_filter_channels,
-    in_data_channel_chunks,
-    in_data_channel_block,
-    in_data_channel_tail,
+    in_chunks,
+    in_block,
+    in_original_tail,
     kernel_h,
     kernel_w,
 ):
     """
     Adds compute stages for packing of the filter in runtime. Extends channel dimensions
     to be divisible by a factor of 4.
+    This function should be substituted by Schedule.transform_layout() in the future: see
+    https://github.com/apache/tvm-rfcs/blob/main/rfcs/0039-buffer-physical-layout.md
 
     Parameters
     ----------
     Filter: tvm.te.Tensor
@@ -175,26 +181,26 @@ def pack_filter(
         Layout of origin 4d tensor
         OIHW, HWOI or HWIO are acceptable
-    out_channel_chunks: int
+    out_chunks: int
         Number of chunks for filters
-    out_channel_block: int
-        Size of the block
+    out_block: int
+        Size of the block for output channels
-    out_channel_tail: int
+    out_original_tail: int
         Original size of the latest chunk of output filters
     in_filter_channels: int
         Number of filter channels; might differ from the input channels in the
         data due to groups/depthwise nature
-    in_data_channel_chunks: int
-        Number of chunks by channels for input data
+    in_chunks: int
+        Number of input data channel chunks
-    in_data_channel_block: int
+    in_block: int
         Size of the block for input data channels
-    in_data_channel_tail
+    in_original_tail: int
         Original size of the latest chunk for input data channels
     kernel_h: int
@@ -207,75 +213,75 @@

     def _reorder_weights_depthwise_oihw(*indices):
         conditionA = []
-        conditionA.append(indices[0] == out_channel_chunks - 1)
-        conditionA.append(indices[4] >= out_channel_tail)
+        conditionA.append(indices[0] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
         conditionAT = tvm.tir.all(*conditionA)
 
         return tvm.tir.if_then_else(
             conditionAT,
             pad_value,
-            Filter[indices[0] * out_channel_block + indices[4], indices[1], indices[2], indices[3]],
+            Filter[indices[0] * out_block + indices[4], indices[1], indices[2], indices[3]],
         )
 
     def _reorder_weights_depthwise_hwoi(*indices):
         conditionA = []
-        conditionA.append(indices[2] == out_channel_chunks - 1)
-        conditionA.append(indices[4] >= out_channel_tail)
+        conditionA.append(indices[2] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
         conditionAT = tvm.tir.all(*conditionA)
 
         return tvm.tir.if_then_else(
             conditionAT,
             pad_value,
-            Filter[indices[0], indices[1], indices[2] * out_channel_block + indices[4], indices[3]],
+            Filter[indices[0], indices[1], indices[2] * out_block + indices[4], indices[3]],
         )
 
     def _reorder_weights_oihw(*indices):
         conditionA = []
-        conditionA.append(indices[0] == out_channel_chunks - 1)
-        conditionA.append(indices[4] >= out_channel_tail)
+        conditionA.append(indices[0] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
         conditionAT = tvm.tir.all(*conditionA)
 
         conditionO = []
         conditionO.append(conditionAT)
         conditionO.append(
-            indices[1] >= in_data_channel_chunks * in_data_channel_block + in_data_channel_tail
+            indices[1] >= in_chunks * in_block + in_original_tail
         )
         conditionOT = tvm.tir.any(*conditionO)
         return tvm.tir.if_then_else(
             conditionOT,
             pad_value,
-            Filter[indices[0] * out_channel_block + indices[4], indices[1], indices[2], indices[3]],
+            Filter[indices[0] * out_block + indices[4], indices[1], indices[2], indices[3]],
         )
 
     def _reorder_weights_hwio(*indices):
         conditionA = []
-        conditionA.append(indices[3] == out_channel_chunks - 1)
-        conditionA.append(indices[4] >= out_channel_tail)
+        conditionA.append(indices[3] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
         conditionAT = tvm.tir.all(*conditionA)
 
         conditionO = []
         conditionO.append(conditionAT)
         conditionO.append(
-            indices[2] >= in_data_channel_chunks * in_data_channel_block + in_data_channel_tail
+            indices[2] >= in_chunks * in_block + in_original_tail
        )
         conditionOT = tvm.tir.any(*conditionO)
         return tvm.tir.if_then_else(
             conditionOT,
             pad_value,
-            Filter[indices[0], indices[1], indices[2], indices[3] * out_channel_block + indices[4]],
+            Filter[indices[0], indices[1], indices[2], indices[3] * out_block + indices[4]],
         )
 
     if in_filter_channels == 1:
         if layout == "OIHW":
             reordered_filter = te.compute(
-                [out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block],
+                [out_chunks, in_filter_channels, kernel_h, kernel_w, out_block],
                 _reorder_weights_depthwise_oihw,
                 name="filter_pack",
                 tag="filter_pack",
             )
         elif layout == "HWOI":
             reordered_filter = te.compute(
-                [kernel_h, kernel_w, out_channel_chunks, in_filter_channels, out_channel_block],
+                [kernel_h, kernel_w, out_chunks, in_filter_channels, out_block],
                 _reorder_weights_depthwise_hwoi,
                 name="filter_pack",
                 tag="filter_pack",
@@ -285,14 +291,14 @@ def _reorder_weights_hwio(*indices):
     else:
         if layout == "OIHW":
             reordered_filter = te.compute(
-                [out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block],
+                [out_chunks, in_filter_channels, kernel_h, kernel_w, out_block],
                 _reorder_weights_oihw,
                 name="filter_pack",
                 tag="filter_pack",
             )
         elif layout == "HWIO":
             reordered_filter = te.compute(
-                [kernel_h, kernel_w, in_filter_channels, out_channel_chunks, out_channel_block],
+                [kernel_h, kernel_w, in_filter_channels, out_chunks, out_block],
                 _reorder_weights_hwio,
                 name="filter_pack",
                 tag="filter_pack",
@@ -317,7 +323,7 @@ def expand_spatial_dimensions(
         Height of the feature map
     in_width: int
-        Width of the featrue map
+        Width of the feature map
 
     kernel_h: int
@@ -503,7 +509,7 @@ def bind_data_copy(stage, axis_to_vectorize=None):
         fused = stage.fuse(ax0, ax1, ax2, oax3)
 
         ftc = numpy.prod(shape) / 4
-        div = getDiv(ftc, 128)
+        div = get_div(ftc, 128)
         block, thread = stage.split(fused, factor=div)
 
         stage.bind(block, te.thread_axis("blockIdx.z"))
@@ -513,7 +519,7 @@
         fused = stage.fuse(*axes[:-1])
         if shape[-1] <= 32:
             ftc = numpy.prod(shape[:-1])
-            div = getDiv(ftc, 64)
+            div = get_div(ftc, 64)
             block, thread = stage.split(fused, factor=div)
             stage.bind(block, te.thread_axis("blockIdx.x"))
             stage.bind(thread, te.thread_axis("threadIdx.x"))
src/runtime/opencl/opencl_common.h (5 changes: 4 additions & 1 deletion)
@@ -345,7 +345,10 @@ struct BufferDescriptor {
    * e.g. image2d[height=O, width=IHW]
    */
   kImage2DWeight,
-  kTexture2DNHWC,
+  /*! \brief Two dimensional texture w/ height = axis[1]
+   * e.g. image2d[height=NH, width=WC]
+   */
+  kImage2DNHWC,
 };
 BufferDescriptor() = default;
 explicit BufferDescriptor(Optional<String> scope) : layout(MemoryLayoutFromScope(scope)) {}
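The renamed enumerator's comment pins down the packing: an NHWC tensor becomes a 2-d image with height = N*H and width = W*C. A small round-trip check of that mapping (pure Python; the shape is illustrative):

# image2d[height=NH, width=WC]: every NHWC element must land on a unique texel.
N, H, W, C = 2, 3, 4, 4

def to_image2d(n, h, w, c):
    return n * H + h, w * C + c  # (row, col)

texels = {
    to_image2d(n, h, w, c)
    for n in range(N)
    for h in range(H)
    for w in range(W)
    for c in range(C)
}
assert len(texels) == N * H * W * C  # bijective: no two elements collide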