diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 9334b94b7cf94..168a8ba1757cb 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -227,7 +227,8 @@ def set_task(self, task): def get_build_kwargs(self): kwargs = {} - if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys: + if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys or \ + 'rocm in self.task.target.keys': remote = request_remote(self.key, self.host, self.port) ctx = remote.context(str(self.task.target), 0) max_dims = ctx.max_thread_dimensions @@ -345,7 +346,6 @@ def set_task(self, task): def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None): """Common part for building a configuration""" target, task, config = measure_input - with target: s, args = task.instantiate(config)