diff --git a/include/tvm/auto_scheduler/cost_model.h b/include/tvm/auto_scheduler/cost_model.h index 4a22b4708fded..856c907695eff 100644 --- a/include/tvm/auto_scheduler/cost_model.h +++ b/include/tvm/auto_scheduler/cost_model.h @@ -42,7 +42,7 @@ using runtime::TypedPackedFunc; class CostModelNode : public Object { public: /*! - * \brief Update the cost model according to new measurement pairs (training data). + * \brief Update the cost model according to new measurement results (training data). * \param inputs The measure inputs * \param results The measure results */ @@ -50,7 +50,7 @@ class CostModelNode : public Object { /*! * \brief Predict the scores of states - * \param task The search task + * \param task The search task of states * \param states The input states * \param scores The predicted scores for all states */ diff --git a/python/tvm/auto_scheduler/cost_model/cost_model.py b/python/tvm/auto_scheduler/cost_model/cost_model.py index 9a366fab51b1d..bb9af255fc2c2 100644 --- a/python/tvm/auto_scheduler/cost_model/cost_model.py +++ b/python/tvm/auto_scheduler/cost_model/cost_model.py @@ -28,18 +28,40 @@ class CostModel(Object): """The base class for cost model""" +@tvm._ffi.register_object("auto_scheduler.RandomModel") +class RandomModel(CostModel): + """A model returns random estimation for all inputs""" + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.RandomModel) + def update(self, inputs, results): + """Update the cost model according to new measurement results (training data). + + Parameters + ---------- + inputs : List[MeasureInput] + The measurement inputs + results : List[MeasureResult] + The measurement results + """ _ffi_api.CostModelUpdate(self, inputs, results) def predict(self, search_task, states): - return _ffi_api.CostModelPredict(self, search_task, states) + """Predict the scores of states + Parameters + ---------- + search_task : SearchTask + The search task of states + statse : List[State] + The input states -@tvm._ffi.register_object("auto_scheduler.RandomModel") -class RandomModel(CostModel): - """A model returns random estimation for all inputs""" - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.RandomModel) + Returns + ------- + scores: List[float] + The predicted scores for all states + """ + return [x.value for x in _ffi_api.CostModelPredict(self, search_task, states)] @tvm._ffi.register_func("auto_scheduler.cost_model.random_number") @@ -74,10 +96,47 @@ def predict_stage_func(task, states, return_ptr): predict_func, predict_stage_func) def update(self, inputs, results): + """Update the cost model according to new measurement results (training data). + + Parameters + ---------- + inputs : List[MeasureInput] + The measurement inputs + results : List[MeasureResult] + The measurement results + """ raise NotImplementedError def predict(self, task, states): + """Predict the scores of states + + Parameters + ---------- + search_task : SearchTask + The search task of states + statse : List[State] + The input states + + Returns + ------- + scores: List[float] + The predicted scores for all states + """ raise NotImplementedError def predict_stages(self, task, states): + """Predict the scores of states + + Parameters + ---------- + search_task : SearchTask + The search task of states + statse : List[State] + The input states + + Returns + ------- + scores: List[float] + The predicted scores for all stages in all states in packed format + """ raise NotImplementedError diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py index e45537b18fbaa..baad4bb03765e 100644 --- a/tests/python/unittest/test_auto_scheduler_cost_model.py +++ b/tests/python/unittest/test_auto_scheduler_cost_model.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Test cost model""" +"""Test cost models""" import tvm from tvm import auto_scheduler