
[TOPI] conv2d nchw gpu scheduler #315

Merged: 19 commits merged into apache:master on Aug 14, 2017

Conversation

Laurawly (Contributor)

No description provided.

@tqchen changed the title from "mxnet.contrib.topi conv2d gpu scheduler moved to topi addressing mxnet-tvm issue #41" to "conv2d gpu scheduler" on Aug 12, 2017
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm


Member:

two new lines between things

The computation graph description of conv2d_nchw in the format
of a list of tensors.

traget: str
Member:

target is not needed in here

Member:

since it is under namespace cuda, target defaults to cuda

@@ -0,0 +1,137 @@
# pylint: disable=invalid-name
"""Schedule for conv2d_nchw with auto fusion, optimized for batch_size(n)=1."""
Member:

remove "optimized for batch_size=1". We might want to provide other schedules later.

s: Schedule
The computation schedule for conv2d_nchw.
"""
s = tvm.create_schedule([x.op for x in outs])
Member:

consider moving this to an internal function named schedule_conv_small_batch and calling that function from here. Check the batch size and raise a RuntimeError when the batch size is large.

Contributor Author:

What limit shall I put on the batch size?

Member:

for now we can set it to 1, and check later whether this schedule works well for batch sizes bigger than 1

The computation graph description of conv2d_nchw in the format
of a list of tensors.

traget: str
Member:

since it is under namespace cuda, target defaults to cuda


Parameters
----------
outs: tvm.Array<tvm::Tensor>
Member:

change to Array of Tensor, since it is Python

@tqchen changed the title from "conv2d gpu scheduler" to "conv2d nchw gpu scheduler" on Aug 12, 2017

@tvm.register_func("topi.schedule.cuda.conv2d_hwcn")
def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn.
Member:

and add a note that we can schedule conv2d_hwcn plus any elementwise operation

"""
s = tvm.create_schedule([x.op for x in outs])
def schedule(Apad, W, B):

Member:

no new line in here


Parameters
----------
outs: Array<Tensor>
Member:

Array of Tensor


Parameters
----------
outs: Array<Tensor>
Member:

array of tensor

wfactor=block_h
ifactor=in_filter/4
sfactor=max(1, ofactor/(opart2*2))
spart = int(math.ceil(wfactor/vthread))
Member:

This is bad, because it is python2/3 dependent. If you want the ceil behaviour, do float(wfactor)/vthread.

Contributor Author:

float(wfactor)/vthread doesn't help me to round up.

Member:

Sorry, you should do int(math.ceil(float(wfactor)/vthread)), or (wfactor + vthread - 1) // vthread
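
For reference, both spellings round up the same way and do not depend on Python 2 vs Python 3 division; a quick check with made-up example values:

import math

wfactor, vthread = 7, 2                            # illustrative values only
a = int(math.ceil(float(wfactor) / vthread))       # explicit float division, then ceil -> 4
b = (wfactor + vthread - 1) // vthread             # pure integer round-up -> 4
assert a == b == 4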

if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if str(tensor.op.input_tensors) != str([]):
Contributor:

if tensor.op.input_tensors:
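
For context, a minimal sketch (illustrative names, not the merged code) of how that check reads inside the traversal that inlines elementwise ops before scheduling the conv stage:

def make_traverse(s, schedule_conv):
    """Sketch: inline ewise/bcast ops and hand 'conv'-tagged ops to schedule_conv."""
    def traverse(op):
        if 'ewise' in op.tag or 'bcast' in op.tag:
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:   # placeholders have no inputs, so stop there
                    traverse(tensor.op)
        if 'conv' in op.tag:
            schedule_conv(op.input_tensors[0], op.input_tensors[1], op.output(0))
    return traverse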

import tvm


def schedule_conv2d_nchw(outs):
Member:

rename this function to schedule_conv2d_nchw_small_batch, and add:

def schedule_conv2d_nchw(outs):
    batch_size = tvm.ir_pass.Simplify(outs[0].op.output(0).shape[0]).value
    if batch_size > 1:
        raise RuntimeError("Batch size: %d is too large for this schedule" % batch_size)
    return schedule_conv2d_nchw_small_batch(outs)

@Huyuwei (Contributor) commented Aug 13, 2017

I am afraid there are two bugs in schedule_conv2d_nchw.

1. shared memory overflow

input size = (1, 512, 14, 14)
kernel size = (512, 512, 1, 1)
pad = 0, stride = 1

error message:

raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [14:21:58] src/codegen/llvm/llvm_module.cc:58: Check failed: ret == 0 (-1 vs. 0) [14:21:58] src/runtime/cuda/cuda_module.cc:93: CUDAError: cuModuleLoadData(&(module_[device_id]), data_.c_str()) failed with error: CUDA_ERROR_INVALID_PTX

printed IR:

  // attr [placeholder.shared] storage_scope = "shared"
  allocate placeholder.shared[float32 * 256 * 128 * 1 * 1]

The shared memory size is too large.
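
For a rough sense of scale (assuming float32, i.e. 4 bytes per element, and the common 48 KB per-block static shared-memory limit on GPUs of that era):

shared_floats = 256 * 128 * 1 * 1    # from the printed IR allocation
shared_bytes = shared_floats * 4     # float32: 131072 bytes = 128 KB
assert shared_bytes == 128 * 1024    # far above a 48 KB per-block budget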

@Huyuwei (Contributor) commented Aug 13, 2017

2. iter_var(threadIdx.y, , threadIdx.y) domain already inferred, cannot prove their extents are the same

input size = (1, 1024, 7, 7)
kernel size = (1024, 1024, 1, 1)
pad = 0, stride = 1

error message:

/home/hyw/mxnet-tvm/tvm/dmlc-core/include/dmlc/logging.h:308: [14:31:38] src/schedule/message_passing.cc:34: Check failed: match iter_var(threadIdx.y, , threadIdx.y) domain already inferred, cannot prove their extents are the same 4 vs 3

thread_y is bound three times:

s[Out].bind(iiw, thread_y)
s[temp_S].bind(iw, thread_y)
s[Filter_S].bind(ii, thread_y)

the ranges of iiw, iw, and ii should be the same, or it will generate an error.
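
One plausible reading of the "4 vs 3" for this 7x7 workload, assuming the schedule runs under Python 2 so that spart = int(math.ceil(wfactor/vthread)) uses integer division (which ties in with the math.ceil discussion above):

import math

out_width = 7                                           # output shape is (1, 1024, 7, 7)
wfactor, vthread = out_width, 2                         # block_h stays 7 (not divisible by 48 or 32)
iiw_extent = int(math.ceil(float(wfactor) / vthread))   # split(iw, nparts=2) of extent 7 -> extent 4
spart_py2 = int(math.ceil(wfactor // vthread))          # Python 2: 7/2 == 3, so spart == 3
assert (iiw_extent, spart_py2) == (4, 3)                # threadIdx.y gets bound to extents 4 and 3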

@tqchen (Member) commented Aug 13, 2017

The invalid PTX was caused by shared mem overflow, so they are actually one error

@tqchen (Member) commented Aug 13, 2017

@Huyuwei Can you also post your test code, so that it can be reused by @Laurawly?

@Laurawly (Contributor, Author)

Currently my ifactor and ofactor are dependent on the input workload sizes, which should be fixed. Thanks for catching that!

@Huyuwei (Contributor) commented Aug 13, 2017

All functions are copied into one file:

import tvm
import math
import numpy as np

def get_const_tuple(in_tuple):
    """Verifies input tuple is IntImm, returns tuple of int.

    Parameters
    ----------
    in_tuple : tuple of tvm.expr.IntImm
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output.
    """
    out_tuple = ()
    for elem in in_tuple:
        if not isinstance(elem, tvm.expr.IntImm):
            raise ValueError("Element of input tuple should be IntImm")
        out_tuple = out_tuple + (elem.value, )
    return out_tuple

@tvm.tag_scope(tag='convolution')
def compute_convolution(data, kernel, HPAD, WPAD, HSTR, WSTR):
    N, IC, H, W = get_const_tuple(data.shape)
    OC, IC, HK, WK = get_const_tuple(kernel.shape)
    TH = H + 2*HPAD
    TW = W + 2*WPAD
    OH = (H + 2*HPAD - HK) / HSTR + 1
    OW = (W + 2*WPAD - WK) / WSTR + 1

    ic = tvm.reduce_axis((0, IC), name='ic')
    dh = tvm.reduce_axis((0, HK), name='dh')
    dw = tvm.reduce_axis((0, WK), name='dw')
    temp = tvm.compute((N, IC, TH, TW), lambda i, ic, h, w: \
        tvm.select(
            tvm.make.Or(tvm.make.Or((h < HPAD), (h >= H + HPAD)),
                        tvm.make.Or((w < WPAD), (w >= W + WPAD))),
            0.0,
            data[i, ic, h - HPAD, w - WPAD]), name='temp')
    return tvm.compute((N, OC, OH, OW), lambda i, oc, h, w: \
        tvm.sum(temp[i, ic, h*HSTR+dh, w*WSTR+dw] * kernel[oc, ic, dh, dw],
                axis=[ic, dh, dw]))


def schedule_conv2d_nchw(outs, target):
    """WIP Schedule for convolution (nchw), optimized for batch_size(n)=1."""
    s = tvm.create_schedule([x.op for x in outs])
    def schedule(temp, Filter, Output):
        out_height = tvm.ir_pass.Simplify(Output.shape[2]).value
        out_width = tvm.ir_pass.Simplify(Output.shape[3]).value
        channel_multiplier = tvm.ir_pass.Simplify(Filter.shape[1]).value

        block_h = out_width
        block_w = tvm.ir_pass.Simplify(temp.shape[1]).value
        if block_h % 48 == 0:
            block_h = 48
        elif block_h % 32 == 0:
            block_h = 32
        if block_w % 48 == 0:
            block_w = 48
        elif block_w % 32 == 0:
            block_w = 32

        s[temp].compute_inline()

        temp_S   = s.cache_read(temp, "shared", [Output])
        Filter_S = s.cache_read(Filter, "shared", [Output])
        temp_L = s.cache_read(temp_S, "local", [Output])
        Filter_L = s.cache_read(Filter_S, "local", [Output])

        if outs[0].op in s.outputs:
            Out = Output
            Out_L = s.cache_write(Out, "local")
        else:
            Out = outs[0].op.output(0)
            s[Output].set_scope("local")
            Out_L = Output

        # scheduler params
        tile = 8
        # num_thread = 8
        step = 16
        vthread = 2
        out_filter = tvm.ir_pass.Simplify(Filter.shape[0]).value
        in_filter = tvm.ir_pass.Simplify(Filter.shape[1]).value
        opart2 = out_filter/8
        ofactor=out_filter
        wfactor=block_h
        ifactor=in_filter/4
        sfactor=max(1, ofactor/(opart2*2))
        spart = int(math.ceil(wfactor/vthread))

        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        block_z = tvm.thread_axis("blockIdx.z")
        thread_x = tvm.thread_axis("threadIdx.x")
        thread_y = tvm.thread_axis("threadIdx.y")
        thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
        thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

        i, oc, h, w = s[Out].op.axis
        ooc, ioc = s[Out].split(oc, factor=ofactor)
        ow, iw = s[Out].split(w, factor=wfactor)
        ow = s[Out].fuse(ow, h)
        oioc, iioc = s[Out].split(ioc, nparts = vthread)
        oiw, iiw = s[Out].split(iw, nparts=vthread)
        oiioc, iiioc = s[Out].split(iioc, nparts = opart2)
        s[Out].reorder(i, ooc, ow, oioc, oiw, oiioc, iiw, iiioc)
        s[Out].bind(iiioc, thread_x)
        s[Out].bind(iiw, thread_y)
        s[Out].bind(oiioc, thread_xz)
        s[Out].bind(oiw, thread_yz)
        s[Out].bind(oioc, block_x)
        s[Out].bind(ow, block_y)
        s[Out].bind(ooc, block_z)

        s[Out_L].compute_at(s[Out], iiioc)

        # schedule Out_L local write
        i, oc, h, w = s[Out_L].op.axis
        ic, dh, dw = s[Out_L].op.reduce_axis
        oic, iic = s[Out_L].split(ic, factor=ifactor)
        s[Out_L].reorder(oic, dh, dw, iic, h, w)
        fuse_index = s[Out_L].fuse(dw, dh)
        fuse_index = s[Out_L].fuse(fuse_index, oic)
        dw = fuse_index
        s[temp_S].compute_at(s[Out_L], dw)
        s[Filter_S].compute_at(s[Out_L], dw)
        s[temp_L].compute_at(s[Out_L], iic)
        s[Filter_L].compute_at(s[Out_L], iic)
       
        #schedule temp_S shared mem load
        i, ic, h, w = s[temp_S].op.axis
        oic, iic = s[temp_S].split(ic, factor=sfactor)
        _, iw = s[temp_S].split(w, nparts=1)
        ow, iw = s[temp_S].split(iw, factor=spart)
        s[temp_S].bind(iic, thread_x)
        s[temp_S].bind(iw, thread_y)
       
        #schedule Filter_S shared mem load
        i, oc, h, w = s[Filter_S].op.axis
        ooc, ioc = s[Filter_S].split(oc, factor=sfactor)
        _, ii = s[Filter_S].split(i, nparts=1)
        oi, ii = s[Filter_S].split(ii, factor=spart)
        s[Filter_S].bind(ioc, thread_x)
        s[Filter_S].bind(ii, thread_y)
    
    def traverse(OP):
        # inline all one-to-one-mapping operators except the last stage (output)
        if 'ewise' in OP.tag or 'bcast' in OP.tag:
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if str(tensor.op.input_tensors) != str([]):
                    traverse(tensor.op)
        # schedule conv2d
        if 'conv' in OP.tag:
            temp = OP.input_tensors[0]
            Filter = OP.input_tensors[1]
            Output = OP.output(0)
            schedule(temp, Filter, Output)

    traverse(outs[0].op)
    return s


in_channel = 1024
channel_height = 7
channel_width = 7

out_channel = 1024

pad = 0
stride = 1
filter_size = 1

data = tvm.placeholder((1, in_channel, channel_height, channel_width))
weight = tvm.placeholder((out_channel, in_channel, filter_size, filter_size))
conv = [compute_convolution(data, weight, pad, pad, stride, stride),]

schedule = schedule_conv2d_nchw(conv, "cuda")
print tvm.lower(schedule, [data, weight, conv[0]], simple_mode=True)

f = tvm.build(schedule, [data, weight, conv[0]], "cuda")

data_np = np.random.uniform(size=get_const_tuple(data.shape)).astype(data.dtype)
weight_np = np.random.uniform(size=get_const_tuple(weight.shape)).astype(weight.dtype)
data_tvm = tvm.nd.array(data_np, tvm.gpu(0))
weight_tvm = tvm.nd.array(weight_np, tvm.gpu(0))
conv_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(conv[0].shape), dtype=conv[0].dtype), tvm.gpu(0))

timer = f.time_evaluator(f.entry_name, tvm.gpu(0), number=1)
tcost = timer(data_tvm, weight_tvm, conv_tvm).mean

print("average time cost of 1 runs (conv) = %g sec" % tcost)

@tqchen (Member) commented Aug 13, 2017

Will merge after the changes are made to support larger workloads and unit test cases are added

@tqchen changed the title from "conv2d nchw gpu scheduler" to "[TOPI] conv2d nchw gpu scheduler" on Aug 13, 2017
@tqchen (Member) commented Aug 14, 2017

@Huyuwei can you verify the commit and approve if it passes your review?

s: Schedule
The computation schedule for conv2d_hwcn.
"""
sch = tvm.create_schedule([x.op for x in outs])
Member:

add the following line before this, so outs can also be a single tensor

outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs

Member:

do the same thing for nchw schedule
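
Putting that together with the batch-size check suggested earlier, a minimal sketch of how the nchw entry point could look (assuming the internal helper is named schedule_conv2d_nchw_small_batch as suggested above; this is not the merged code):

import tvm

def schedule_conv2d_nchw(outs):
    """Schedule for conv2d_nchw plus any trailing elementwise operations."""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    batch_size = tvm.ir_pass.Simplify(outs[0].op.output(0).shape[0]).value
    if batch_size > 1:
        raise RuntimeError("Batch size: %d is too large for this schedule" % batch_size)
    return schedule_conv2d_nchw_small_batch(outs)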


def schedule(temp, Filter, Output):
"""Schedule conv2d_nchw"""
block_h = tvm.ir_pass.Simplify(Output.shape[3]).value
Member:

use

from .. import util

block_h = util.get_const_int(Out.shape[3])

Member:

Please rebase after #319 gets merged

num_thread = 8
vthread = 2
out_filter = min(64, tvm.ir_pass.Simplify(Filter.shape[0]).value)
in_filter = tvm.ir_pass.Simplify(Filter.shape[1]).value
Member:

Same here; for most cases we shouldn't explicitly use Simplify

s: Schedule
The computation schedule for conv2d_nchw.
"""
batch_size = tvm.ir_pass.Simplify(outs[0].op.output(0).shape[0]).value
Member:

same here

@tqchen merged commit cbff637 into apache:master on Aug 14, 2017
@tqchen (Member) commented Aug 14, 2017

Thanks, this is merged. There are a few things we need from a follow-up PR. This PR creates a good standard for the schedule interface, which should take an Array of Tensors (or a single Tensor) instead of an op. Also, in terms of naming, the map suffix is removed from function and file names, and we assume all complex functions like conv can be scheduled together with follow-up ewise ops.

@Huyuwei please update the depthwise part to reflect this

@Huyuwei (Contributor) commented Aug 14, 2017

@tqchen Got it.

@arassadin commented Mar 19, 2018

Hi everyone.

I got this error while reproducing the toy example from nnvm, but with my own model. Calling

m.run()

I get an error similar to #315 (comment):

TVMError: [09:11:33] src/runtime/cuda/cuda_module.cc:93: CUDAError: cuModuleLoadData(&(module_[device_id]), data_.c_str()) failed with error: CUDA_ERROR_INVALID_PTX

Can you clarify what could be wrong here?

Thanks in advance!


BTW, I'm a bit confused by the tvm.gpu() docstring 😃:

Construct a CPU device

@arassadin

@tqchen, @Laurawly, @Huyuwei can you take a quick look?

@tqchen (Member) commented Mar 20, 2018

@arassadin please open a new issue for new questions. In your case, it is likely that the gpu schedule for nchw did not work for your specific conv2d shape and the nvcc compiler failed to compile it.

@arassadin

Ok, thanks
