
[BugFix] SSD fully supported on GPUs, updated deploy_ssd tutorial #2510

Merged 12 commits into apache:master on Jan 30, 2019

Conversation

Laurawly
Contributor

Thanks to @vinx13's PR #2420, argsort now works on GPUs.
Tested the full SSD pipeline on an NVIDIA K80c and on Intel HD Graphics. Performance improved compared with the heterogeneous results.
Please review @masahi @kevinthesun @zhiics

@masahi masahi self-assigned this Jan 25, 2019
#ctx = tvm.gpu(0)
# Use these commented settings to build for opencl.
#target = 'opencl'
#ctx = tvm.gpu(0)
Member

If I remember correctly, for OpenCL it should be tvm.opencl(0) or tvm.cl(0), shouldn't it?

Contributor Author

Yes, sorry, I forgot to change it.
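
For reference, a minimal sketch of the corrected settings (a sketch only; the tvm.gpu / tvm.cl / tvm.opencl device helpers are assumed from the TVM API of that era):

import tvm

# CUDA build
target = 'cuda'
ctx = tvm.gpu(0)

# OpenCL build (use these instead of the CUDA settings above)
# target = 'opencl'
# ctx = tvm.cl(0)  # equivalently tvm.opencl(0)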

@masahi
Member

masahi commented Jan 25, 2019

@Laurawly @vinx13 can we share the sorting IR in this PR and @vinx13's PR #2420? They look identical.

@Laurawly
Contributor Author

I agree we should put sort in a common file, and we can add a unit test for it as well.
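
A hypothetical shape for such a test (the sort_gpu_fn callable is assumed here for illustration; it is not part of this PR):

import numpy as np

def check_argsort(sort_gpu_fn, shape=(1, 100)):
    data = np.random.uniform(size=shape).astype("float32")
    expected = np.argsort(-data, axis=1)   # descending order, as NMS expects
    result = sort_gpu_fn(data)             # assumed to return int32 indices
    np.testing.assert_array_equal(result, expected)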

with ib.for_range(0, batch, for_type="unroll") as b:
    start = b * num_anchors
    with ib.if_scope(tid < num_anchors):
        p_out[start + tid] = tid
Member

Seems storage_sync is missing here; I will update my PR.

Contributor Author

@vinx13 Would you like to split argsort into a separate file so that we can share it? I can add a unit test for it if needed.

Member

@Laurawly What's needed in SSD? It seems you changed num_bbox in my PR to p_index[0]; why is only the first element of p_index used?

Contributor

Maybe we can make argsort a normal TOPI op? I'll add a CPU implementation later.

Contributor Author

@Laurawly Laurawly Jan 28, 2019

@vinx13 p_index is the valid_count variable, a 1D array produced by the multibox operators. So instead of sorting all data.shape[1] numbers, we only need to sort the first p_index[0] numbers.
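
A small NumPy illustration of that point (the shapes and counts here are made up):

import numpy as np

scores = np.random.uniform(size=(1, 5829)).astype("float32")  # (batch, num_anchors)
valid_count = np.array([100], dtype="int32")                  # 1D: one entry per batch

# Only the first valid_count[0] anchors hold real detections,
# so it is enough to argsort that prefix (descending by score).
order = np.argsort(-scores[0, :valid_count[0]])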

Member

@Laurawly Shouldn't it be p_index[batch_id]? Are you assuming batch = 1?

Contributor Author

@vinx13 p_index only has one dimension, so it should be p_index[0].

Member

@kevinthesun @Laurawly The difficulty of sharing argsort (or extracting it as a TOPI operator) is that we want sort_num to be either a tvm.Tensor or a constant array, but we can't use a tvm.Expr to subscript a Python array. Do you have any ideas?
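
A hypothetical snippet of the problem (illustrative only; the loop variable returned by ib.for_range is a TVM expression, so it cannot index a Python list):

import tvm

ib = tvm.ir_builder.create()
sort_num_const = [10, 20, 30]          # the "constant Python array" case
with ib.for_range(0, 3, name="b") as b:
    # b is a tvm expression, so sort_num_const[b] raises
    # "TypeError: list indices must be integers"; a tvm.Tensor, by contrast,
    # would have to be read through its Buffer inside the generated IR.
    pass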

@yzhliu yzhliu mentioned this pull request Jan 28, 2019
with ib.else_scope():
    start = sizes[tid - 1]
    p_out[base_idx + k * axis_mul_after] = tvm.if_then_else(
        k < p_index[tid], index_new[k + start], k)
Member

@Laurawly I'm still confused: if batch > 1, it should enter this if branch (since axis_mul_before * axis_mul_after > 1). Does p_index[tid] here mean that each batch has a different valid count?

Contributor Author

@vinx13 From here https://github.com/dmlc/tvm/blob/master/topi/python/topi/cuda/nms.py#L368 axis is always 1, so axis_mul_before and axis_mul_after are both 1.

Member

@Laurawly Since ndim == 2 and axis == 1, the actual loop is like:

for i in range(0, 2):
    if i < 1:
        axis_mul_before *= data.shape[i]

I assume axis_mul_after == 1 and axis_mul_before == data.shape[0], which is the batch size, right?
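
For concreteness, a small standalone check of that product (the shape is made up; this just mirrors the usual before/after multiplication pattern):

data_shape = (4, 5829)  # hypothetical (batch, num_anchors)
axis = 1
axis_mul_before = 1
axis_mul_after = 1
for i in range(len(data_shape)):
    if i < axis:
        axis_mul_before *= data_shape[i]
    elif i > axis:
        axis_mul_after *= data_shape[i]
print(axis_mul_before, axis_mul_after)  # 4 1 -> axis_mul_before is the batch size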

Contributor Author

@Laurawly Laurawly Jan 28, 2019

@vinx13 Yeah, that's right. I see what you mean. So each batch could have a different valid count when batch_size > 1. I shouldn't have assumed batch_size = 1. I just pushed the changes.

@vinx13
Member

vinx13 commented Jan 28, 2019

@Laurawly By the way, have you checked for data races in the NMS IR? It seems __syncthreads and a global barrier are needed on CUDA (maybe we should rewrite the IR to avoid the global barrier). I sometimes get incorrect NMS results in my PR.

@Laurawly
Contributor Author

@vinx13 Does the conflict happen in argsort_ir?

@vinx13
Member

vinx13 commented Jan 29, 2019

@Laurawly The conflict happens in nms_ir. I replaced blockIdx with vthread and added storage_sync, and it worked, but my current solution is not efficient.

@Laurawly
Contributor Author

@vinx13 I don't see conflicts in my nms_ir using blockIdx.x, but I'll double check. Why do you want to replace blockIdx.x with vthread?

@vinx13
Member

vinx13 commented Jan 29, 2019

@Laurawly If data written by other threads is needed (probably this line: if_scope(p_out[b * num_anchors * 6 + offset_l] >= 0)), there may be a data race due to the lack of synchronization.

@Laurawly
Contributor Author

@vinx13 There's no data conflict for p_out in SSD because nms_topk = -1. For the line you mentioned, there is no conflict because the writing and reading of p_out happen in the same thread block (if_scope(p_out[b * num_anchors * 6 + offset_i] >= 0) and p_out[b * num_anchors * 6 + offset_i] = -1.0).

@vinx13
Member

vinx13 commented Jan 30, 2019

@Laurawly The write p_out[base_idx + offset_i] = -1.0 is in thread i, while the read of p_out[base_idx + offset_l] in ib.if_scope(p_out[base_idx + offset_l] >= 0 and p_out[base_idx + offset_l] == p_out[b * num_anchors * 6 + offset_i]) happens in all threads, and there is no synchronization after each iteration of for_range(0, p_valid_count[b]) as l. Is there a conflict in this case?

@Laurawly
Contributor Author

Laurawly commented Jan 30, 2019

@vinx13 No, because there's a condition that i > l, and because you iterate over l sequentially, the write to p_out[base_idx + offset_l] should already be finished by the time you read from it. For example, when l == 0, thread 0 can't write to p_out[base_idx + offset_0], so it returns; threads 1 to n need to read from p_out[base_idx + offset_0], but we already know thread 0 won't write to it. When l == 1, threads 0 and 1 won't write to p_out, while thread 2 may need to read p_out[base_idx + offset_1] before thread 1 finishes writing to it in iteration l == 0. But in that case it means p_out[base_idx + offset_1] == -1 and p_out[base_idx + offset_2] != -1, because thread 2 finishes much earlier than thread 1, so p_out[base_idx + offset_l] == p_out[base_idx + offset_i] won't be true either way.
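
To make the access pattern in this argument easier to follow, here is a simplified, hypothetical sketch (not the actual IR; it assumes 6 values per box with the class id stored first, matching the indices quoted above):

def nms_thread_body(i, p_out, base_idx, valid_count, overlaps, threshold):
    # Conceptually, every GPU thread i runs this body in parallel for its own box i.
    for l in range(valid_count):                  # sequential per-thread loop over l
        offset_l, offset_i = l * 6, i * 6
        if (i > l
                and p_out[base_idx + offset_l] >= 0                           # box l not suppressed
                and p_out[base_idx + offset_l] == p_out[base_idx + offset_i]  # same class
                and overlaps(l, i) >= threshold):
            p_out[base_idx + offset_i] = -1.0     # thread i only ever writes its own box i

In this sketch, thread i writes only offset_i and reads offset_l only for l < i, while box l's sole writer is thread l; that is the structure the i > l guard and the sequential l loop in the comment above refer to.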

@vinx13
Member

vinx13 commented Jan 30, 2019

@Laurawly I see, thanks for your clarification

@masahi masahi merged commit 48c16a1 into apache:master Jan 30, 2019
@masahi
Member

masahi commented Jan 30, 2019

Thanks @Laurawly @vinx13 @kevinthesun @zhiics, this is merged.

merrymercy pushed a commit to merrymercy/tvm that referenced this pull request Feb 18, 2019
…ache#2510)

* nms fixed for gpu, tested on cuda and opencl devices, ssd now can run fully on the gpu

* sort updated to use virtual thread

* typo fixed

* fix lint

* fix lint

* add support when batch_size > 1

* intel graphics conv2d bugs fixed for inception_v3

* intel conv2d api updated, nn input size 4 condition added

* review addressed

* move conv_tags to attributes

* opencl ctx fixed

* nms_ir index simplified
wweic pushed a commit to neo-ai/tvm that referenced this pull request Feb 20, 2019
…ache#2510)

* nms fixed for gpu, tested on cuda and opencl devices, ssd now can run fully on the gpu

* sort updated to use virtual thread

* typo fixed

* fix lint

* fix lint

* add support when batch_size > 1

* intel graphics conv2d bugs fixed for inception_v3

* intel conv2d api updated, nn input size 4 condition added

* review addressed

* move conv_tags to attributes

* opencl ctx fixed

* nms_ir index simplified