Skip to content

Commit

Permalink
[AutoTVM] Support range in index based tuners (#4870)
Browse files Browse the repository at this point in the history
* Support range in index based tuners

* Address comments

* Remove __*state__

* trigger CI
  • Loading branch information
comaniac authored Feb 15, 2020
1 parent a5e54b1 commit feda150
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 104 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/tuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@

from .tuner import Tuner

from .gridsearch_tuner import GridSearchTuner, RandomTuner
from .index_based_tuner import GridSearchTuner, RandomTuner
from .ga_tuner import GATuner
from .xgboost_tuner import XGBTuner
85 changes: 0 additions & 85 deletions python/tvm/autotvm/tuner/gridsearch_tuner.py

This file was deleted.

110 changes: 110 additions & 0 deletions python/tvm/autotvm/tuner/index_based_tuner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=abstract-method
"""Grid search tuner and random tuner"""

import numpy as np

from .tuner import Tuner

class IndexBaseTuner(Tuner):
"""Base class for index based tuner
This type of tuner determine the next batch of configs based on config indices.
Parameters
----------
task: autotvm.task.Task
The tuning task
range_idx: Optional[Tuple[int, int]]
A tuple of index range that this tuner can select from
"""
def __init__(self, task, range_idx=None):
super(IndexBaseTuner, self).__init__(task)
assert range_idx is None or isinstance(range_idx, tuple), \
"range_idx must be None or (int, int)"

self.range_length = len(self.task.config_space)
self.index_offset = 0
if range_idx is not None:
assert range_idx[1] > range_idx[0], "Index range must be positive"
assert range_idx[0] >= 0, "Start index must be positive"
self.range_length = range_idx[1] - range_idx[0] + 1
self.index_offset = range_idx[0]
self.counter = 0

def has_next(self):
return self.counter < self.range_length

def load_history(self, data_set):
pass


class GridSearchTuner(IndexBaseTuner):
"""Enumerate the search space in a grid search order"""

def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.counter >= self.range_length:
break
index = self.counter + self.index_offset
ret.append(self.task.config_space.get(index))
self.counter = self.counter + 1
return ret


class RandomTuner(IndexBaseTuner):
"""Enumerate the search space in a random order
Parameters
----------
task: autotvm.task.Task
Tuning Task
range_idx: Optional[Tuple[int, int]]
A tuple of index range to random
"""
def __init__(self, task, range_idx=None):
super(RandomTuner, self).__init__(task, range_idx)

# Use a dict to mimic a range(n) list without storing rand_state[i] = i entries so that
# we can generate non-repetitive random indices.
self.rand_state = {}
self.rand_max = self.range_length
self.visited = []

def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.rand_max == 0:
break

# Random an indirect index.
index_ = np.random.randint(self.rand_max)
self.rand_max -= 1

# Use the indirect index to get a direct index.
index = self.rand_state.get(index_, index_) + self.index_offset
ret.append(self.task.config_space.get(index))
self.visited.append(index)

# Update the direct index map.
self.rand_state[index_] = self.rand_state.get(self.rand_max, self.rand_max)
self.rand_state.pop(self.rand_max, None)
self.counter += 1
return ret
16 changes: 15 additions & 1 deletion tests/python/unittest/test_autotvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,24 @@
"""Common utilities for testing autotvm"""
import time

import numpy as np

import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
from tvm.autotvm.measure.measure import Runner


class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)

def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]

def get_build_kwargs(self):
return {}

@autotvm.template
def matmul(N, L, M, dtype):
Expand Down Expand Up @@ -82,4 +97,3 @@ def get_sample_records(n):
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
ress.append(MeasureResult((i+1,), 0, i, time.time()))
return list(zip(inps, ress))

68 changes: 68 additions & 0 deletions tests/python/unittest/test_autotvm_index_tuner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test index based tuners"""

from test_autotvm_common import DummyRunner, get_sample_task
from tvm import autotvm
from tvm.autotvm.tuner import GridSearchTuner, RandomTuner


def test_gridsearch_tuner():
"""Test GridSearchTuner"""

task, _ = get_sample_task()
measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())

# When no range index, range_length should be the length of config space
tuner = autotvm.tuner.GridSearchTuner(task)
assert tuner.range_length == len(task.config_space)
assert tuner.index_offset == 0

# With range index, range_length should be the length of the specified range
tuner = autotvm.tuner.GridSearchTuner(task, range_idx=(8, 15))
assert tuner.range_length == 8
assert tuner.index_offset == 8

# Tuner should only focus on the specified range
tuner.tune(n_trial=8, measure_option=measure_option)
assert tuner.counter == 8
assert not tuner.has_next()


def test_random_tuner():
"""Test RandomTuner"""

task, _ = get_sample_task()
measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())

tuner = autotvm.tuner.RandomTuner(task, range_idx=(8, 15))
assert tuner.range_length == 8
assert tuner.index_offset == 8

# Tuner should only focus on the specified range and should visit all indices
tuner.tune(n_trial=8, measure_option=measure_option)
assert tuner.counter == 8
assert not tuner.has_next()
visited = set()
for idx in tuner.visited:
assert idx not in visited
assert 8 <= idx <= 15


if __name__ == '__main__':
test_gridsearch_tuner()
test_random_tuner()
23 changes: 6 additions & 17 deletions tests/python/unittest/test_autotvm_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,14 @@
import numpy as np

import tvm
from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult


def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()

class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)

def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]

def get_build_kwargs(self):
return {}
task, _ = get_sample_task()

measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
Expand All @@ -64,7 +54,7 @@ def test_check_correctness():
)

def _callback_correct(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
for _, res in zip(measure_inputs, measure_results):
assert res.error_no == 0

tuner = autotvm.tuner.RandomTuner(task)
Expand All @@ -77,7 +67,7 @@ def _callback_correct(tuner, measure_inputs, measure_results):
task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)

def _callback_wrong(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
for _, res in zip(measure_inputs, measure_results):
assert res.error_no == MeasureErrorNo.WRONG_ANSWER

tuner = autotvm.tuner.RandomTuner(task)
Expand All @@ -90,4 +80,3 @@ def _callback_wrong(tuner, measure_inputs, measure_results):

test_task_tuner_without_measurement()
test_check_correctness()

0 comments on commit feda150

Please sign in to comment.