Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Aug 2, 2020
1 parent 5fc5154 commit 9203263
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
4 changes: 2 additions & 2 deletions include/tvm/auto_scheduler/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ using runtime::TypedPackedFunc;
class CostModelNode : public Object {
public:
/*!
* \brief Update the cost model according to new measurement pairs (training data).
* \brief Update the cost model according to new measurement results (training data).
* \param inputs The measure inputs
* \param results The measure results
*/
virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0;

/*!
* \brief Predict the scores of states
* \param task The search task
* \param task The search task of states
* \param states The input states
* \param scores The predicted scores for all states
*/
Expand Down
71 changes: 65 additions & 6 deletions python/tvm/auto_scheduler/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,40 @@
class CostModel(Object):
"""The base class for cost model"""

@tvm._ffi.register_object("auto_scheduler.RandomModel")
class RandomModel(CostModel):
"""A model returns random estimation for all inputs"""
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.RandomModel)

def update(self, inputs, results):
"""Update the cost model according to new measurement results (training data).
Parameters
----------
inputs : List[MeasureInput]
The measurement inputs
results : List[MeasureResult]
The measurement results
"""
_ffi_api.CostModelUpdate(self, inputs, results)

def predict(self, search_task, states):
return _ffi_api.CostModelPredict(self, search_task, states)
"""Predict the scores of states
Parameters
----------
search_task : SearchTask
The search task of states
statse : List[State]
The input states
@tvm._ffi.register_object("auto_scheduler.RandomModel")
class RandomModel(CostModel):
"""A model returns random estimation for all inputs"""
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.RandomModel)
Returns
-------
scores: List[float]
The predicted scores for all states
"""
return [x.value for x in _ffi_api.CostModelPredict(self, search_task, states)]


@tvm._ffi.register_func("auto_scheduler.cost_model.random_number")
Expand Down Expand Up @@ -74,10 +96,47 @@ def predict_stage_func(task, states, return_ptr):
predict_func, predict_stage_func)

def update(self, inputs, results):
"""Update the cost model according to new measurement results (training data).
Parameters
----------
inputs : List[MeasureInput]
The measurement inputs
results : List[MeasureResult]
The measurement results
"""
raise NotImplementedError

def predict(self, task, states):
"""Predict the scores of states
Parameters
----------
search_task : SearchTask
The search task of states
statse : List[State]
The input states
Returns
-------
scores: List[float]
The predicted scores for all states
"""
raise NotImplementedError

def predict_stages(self, task, states):
"""Predict the scores of states
Parameters
----------
search_task : SearchTask
The search task of states
statse : List[State]
The input states
Returns
-------
scores: List[float]
The predicted scores for all stages in all states in packed format
"""
raise NotImplementedError
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""Test cost model"""
"""Test cost models"""

import tvm
from tvm import auto_scheduler
Expand Down

0 comments on commit 9203263

Please sign in to comment.