Skip to content

Commit

Permalink
removed im2col profiling for conv2d
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 11, 2021
1 parent b724f44 commit f7d17a1
Showing 1 changed file with 36 additions and 56 deletions.
92 changes: 36 additions & 56 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,59 +169,39 @@ def profile(
If profile_all is False, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
if True:
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
N, H, W, IC = d_shape
OC, R, S, _ = w_shape
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

for op in ops:
op["runtime"] = -1

if profile_all:
self.engine.compile_all(ops, use_multiprocessing)

args = [
"--n=%d" % N,
"--h=%d" % H,
"--w=%d" % W,
"--k=%d" % OC,
"--c=%d" % IC,
"--r=%d" % R,
"--s=%d" % S,
"--pad_h=%d" % padding[0],
"--pad_w=%d" % padding[1],
"--stride_h=%d" % stride[0],
"--stride_w=%d" % stride[1],
"--dilation_h=%d" % dilation[0],
"--dilation_w=%d" % dilation[1],
]
for op in ops:
out = self.engine.evaluate(op, args)
op["runtime"] = out
if out > 0 and profile_all is False:
break

valid_ops = filter(lambda op: op["runtime"] > 0, ops)
output = sorted(valid_ops, key=lambda i: i["runtime"])
# self.cache[(M, N, K)] = output[0]
return output[0]

else:
B, _, _, IC = d_shape
OC, R, S, _ = w_shape
_, P, Q, _ = out_shape

M = B * P * Q
N = OC
K = R * S * IC

gemm_profile_result = self.gemm_profiler.profile(
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
)

tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]

return create_conv2d_operator([tile_description], data_type, [alignment])[0]
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
N, H, W, IC = d_shape
OC, R, S, _ = w_shape
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

for op in ops:
op["runtime"] = -1

if profile_all:
self.engine.compile_all(ops, use_multiprocessing)

args = [
"--n=%d" % N,
"--h=%d" % H,
"--w=%d" % W,
"--k=%d" % OC,
"--c=%d" % IC,
"--r=%d" % R,
"--s=%d" % S,
"--pad_h=%d" % padding[0],
"--pad_w=%d" % padding[1],
"--stride_h=%d" % stride[0],
"--stride_w=%d" % stride[1],
"--dilation_h=%d" % dilation[0],
"--dilation_w=%d" % dilation[1],
]
for op in ops:
out = self.engine.evaluate(op, args)
op["runtime"] = out
if out > 0 and profile_all is False:
break

valid_ops = filter(lambda op: op["runtime"] > 0, ops)
output = sorted(valid_ops, key=lambda i: i["runtime"])
# self.cache[(M, N, K)] = output[0]
return output[0]

0 comments on commit f7d17a1

Please sign in to comment.