Revert Commits, Start to build minimum Ansor system
jcf94 committed Jun 24, 2020
1 parent 86bfd8f commit 910964e
Showing 8 changed files with 27 additions and 706 deletions.
4 changes: 2 additions & 2 deletions python/tvm/ansor/__init__.py
@@ -26,8 +26,8 @@
 
 # Shortcut
 from .compute_dag import ComputeDAG
-from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \
-    PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule
+from .auto_schedule import SearchTask, TuneOption, HardwareParams, \
+    auto_schedule
 from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
 from .cost_model import RandomModel
 from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
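
For downstream code, the visible effect of this change is that the sketch-policy entry points vanish from the public namespace while the core names survive. A minimal illustration, assuming the package is importable as tvm.ansor:

    from tvm.ansor import SearchTask, TuneOption, HardwareParams, auto_schedule  # still exported
    from tvm.ansor import SketchSearchPolicy  # ImportError after this commit
    from tvm.ansor import PreloadMeasuredStates, PreloadCustomSketchRule  # likewise removed
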
2 changes: 1 addition & 1 deletion python/tvm/ansor/compute_dag.py
@@ -44,7 +44,7 @@ def get_init_state(self):
         """
         return State(_ffi_api.ComputeDAGGetInitState(self), self)
 
-    def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE):
+    def apply_steps_from_state(self, state):
         """
         Apply transform steps according to the history of a state
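
At call sites, the layout_rewrite_level keyword simply has to be dropped, matching the layout-rewrite plumbing disabled in src/ansor/compute_dag.cc below. A hedged sketch; the `tensors` input and the (schedule, args) return shape are assumptions for illustration, not verified signatures:

    dag = ComputeDAG(tensors)      # `tensors`: hypothetical list of te tensors
    state = dag.get_init_state()
    # Before this commit:
    #   sch, args = dag.apply_steps_from_state(state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE)
    sch, args = dag.apply_steps_from_state(state)  # after: no layout-rewrite argument
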
20 changes: 20 additions & 0 deletions python/tvm/ansor/loop_state.py
@@ -152,6 +152,26 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True):
         self._clear_cache()
         return res
+
+    def fuse(self, stage_id, iters):
+        """
+        Parameters
+        ----------
+        stage_id : Union[int, Operation, Tensor]
+            The index of the stage to fuse
+        iters : List[Iterator]
+            The iterators to be fused
+
+        Returns
+        -------
+        res_it : Iterator
+            The fused Iterator
+        """
+        stage_id = self._resolve_stage_id(stage_id)
+
+        self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
+        self._clear_cache()
+        return res
 
     def _resolve_stage_id(self, stage_id):
         if isinstance(stage_id, Operation):
             return self.stage_id_map[stage_id]
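
An illustrative round trip with the new primitive, pairing it with the split() shown in the hunk context above (iterator access such as state.stages[0].iters[0], and split() returning one outer and one inner iterator for a single split length, are assumptions about the surrounding State API):

    state = dag.get_init_state()
    outer, inner = state.split(0, state.stages[0].iters[0], [16])  # split an axis by 16
    fused = state.fuse(0, [outer, inner])  # new in this commit: fuse them back into one Iterator
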
5 changes: 2 additions & 3 deletions src/ansor/compute_dag.cc
@@ -38,7 +38,6 @@
 #include <vector>
 #include "transform_step.h"
 #include "search_policy/utils.h"
-#include "../relay/transforms/kernel_layout_transform.h"
 
 namespace tvm {
 namespace ansor {
@@ -737,7 +736,7 @@ void ComputeDAG::RewriteLayout(
       CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size());
       std::string ori_layout = os.str();
       os.str("");
-      ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout);
+      // ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout);
     }
   }

@@ -800,7 +799,7 @@ void ComputeDAG::RewriteLayout(
     }
     std::string new_layout = os.str();
     os.str("");
-    ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout);
+    // ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout);
     placeholder_new_names[placeholder_op] = new_names;
     placeholder_new_shapes[placeholder_op] = new_shape;
 
59 changes: 0 additions & 59 deletions src/ansor/search_task.cc
@@ -52,65 +52,6 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(
   if (target->target_name == "llvm") {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(),
                           32, 64, 16, 64);
-  } else if (target->device_type == kDLGPU) {
-    // TODO(jcf94): temp implementation, max vectorize size in GPU is related
-    // to the data type
-    auto hardware_params = HardwareParams(100000, 16, 64, 4, 64);
-    auto* p_hardware_params = hardware_params.CopyOnWrite();
-
-    auto ctx = TVMContext{kDLGPU, 0};
-    auto func = tvm::runtime::Registry::Get("device_api.gpu");
-    CHECK(func != nullptr) << "Cannot find GPU device_api in registry";
-    auto device_api =
-        static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
-
-    tvm::runtime::TVMRetValue ret;
-    device_api->GetAttr(
-        ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
-    p_hardware_params->max_shared_memory_per_block = ret;
-
-    device_api->GetAttr(
-        ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
-    p_hardware_params->max_registers_per_block = ret;
-
-    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock,
-                        &ret);
-    p_hardware_params->max_threads_per_block = ret;
-
-    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
-    p_hardware_params->warp_size = ret;
-
-    // Manually set now
-    p_hardware_params->max_vthread_extent = 4;
-
-    return hardware_params;
-  } else if (target->device_type == kDLOpenCL) {
-    // TODO(jcf94): temp implementation
-    auto hardware_params = HardwareParams(100000, 16, 64, 4, 64);
-    auto p_hardware_params = hardware_params.CopyOnWrite();
-
-    auto ctx = TVMContext{kDLOpenCL, 0};
-    auto func = tvm::runtime::Registry::Get("device_api.opencl");
-    CHECK(func != nullptr) << "Cannot find GPU device_api in registry";
-    auto device_api =
-        static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
-
-    tvm::runtime::TVMRetValue ret;
-    device_api->GetAttr(
-        ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
-    p_hardware_params->max_shared_memory_per_block = ret;
-
-    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock,
-                        &ret);
-    p_hardware_params->max_threads_per_block = ret;
-
-    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
-    p_hardware_params->warp_size = ret;
-
-    // Manually set now
-    p_hardware_params->max_vthread_extent = 4;
-
-    return hardware_params;
   } else {
     LOG(FATAL) << "No default hardware parameters for target: " << target;
   }
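
With the kDLGPU and kDLOpenCL branches removed, GetDefaultHardwareParams now only covers llvm targets and hits LOG(FATAL) for everything else, so GPU callers have to supply HardwareParams themselves. A hedged Python sketch that reuses the constructor constants from the deleted branch (the positional meaning of the arguments is an assumption):

    from tvm.ansor import HardwareParams

    # The same five values the deleted GPU branch started from, before it
    # overwrote max_shared_memory_per_block, max_threads_per_block, warp_size,
    # and max_vthread_extent with queried device attributes.
    hw_params = HardwareParams(100000, 16, 64, 4, 64)  # placeholder GPU defaults
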