Skip to content

Commit

Permalink
add the base class for cost models
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Aug 2, 2020
1 parent 1823a7b commit 5a49ec7
Show file tree
Hide file tree
Showing 6 changed files with 520 additions and 1 deletion.
160 changes: 160 additions & 0 deletions include/tvm/auto_scheduler/cost_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/auto_scheduler/cost_model.h
* \brief Cost models that estimate the performance of programs
*/

#ifndef TVM_AUTO_SCHEDULER_COST_MODEL_H_
#define TVM_AUTO_SCHEDULER_COST_MODEL_H_

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/node/node.h>
#include <tvm/runtime/packed_func.h>

#include <vector>

namespace tvm {
namespace auto_scheduler {

using runtime::PackedFunc;
using runtime::TypedPackedFunc;

/*! \brief The base class for cost model */
class CostModelNode : public Object {
public:
/*!
* \brief Update the cost model according to new measurement results (training data).
* \param inputs The measure inputs
* \param results The measure results
*/
virtual void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) = 0;

/*!
* \brief Predict the scores of states
* \param task The search task of states
* \param states The input states
* \param scores The predicted scores for all states
*/
virtual void Predict(const SearchTask& task, const std::vector<State>& states,
std::vector<float>* scores) = 0;

/*!
* \brief Predict the scores of all stages in states
* \param task The search task
* \param states The input states
* \param state_scores The predicted scores for all states
* \param stage_scores The predicted scores for all stages in all stages
*/
virtual void PredictStages(const SearchTask& task, const std::vector<State>& states,
std::vector<float>* state_scores,
std::vector<std::vector<float>>* stage_scores) {
LOG(FATAL) << "Not implemented";
}

static constexpr const char* _type_key = "auto_scheduler.CostModel";
TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object);
};

/*!
* \brief Managed reference to CostModelNode.
* \sa CostModelNode
*/
class CostModel : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode);
};

/*! \brief The cost model returning random value for all predictions */
class RandomModelNode : public CostModelNode {
public:
/*! \brief Pointer to a random number generator function */
const TypedPackedFunc<void(size_t, void*)>* random_number_func;

void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;

void Predict(const SearchTask& task, const std::vector<State>& states,
std::vector<float>* scores) final;

static constexpr const char* _type_key = "auto_scheduler.RandomModel";
TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode);
};

/*!
* \brief Managed reference to RandomModelNode.
* \sa RandomModelNode
*/
class RandomModel : public CostModel {
public:
RandomModel();
explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : CostModel(n) {}

RandomModelNode* operator->() const { return static_cast<RandomModelNode*>(data_.get()); }

TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel);
using ContainerType = RandomModelNode;
};

/*! \brief A wrapper for cost model defined by python code
* This class will call functions defined in the python */
class PythonBasedModelNode : public CostModelNode {
public:
/*! \brief Pointer to the update funcion in python */
PackedFunc update_func;
/*! \brief Pointer to the predict funcion in python */
PackedFunc predict_func;
/*! \brief Pointer to the predict funcion in python */
PackedFunc predict_stage_func;

void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;

void Predict(const SearchTask& task, const std::vector<State>& states,
std::vector<float>* scores) final;

void PredictStages(const SearchTask& task, const std::vector<State>& states,
std::vector<float>* state_scores,
std::vector<std::vector<float>>* stage_scores) final;

static constexpr const char* _type_key = "auto_scheduler.PythonBasedModel";
TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode);
};

/*!
* \brief Managed reference to PythonBasedModelNode.
* \sa PythonBasedModelNode
*/
class PythonBasedModel : public CostModel {
public:
/*!
* \brief The constructor.
* \param update_func The pointer to the update function defined in python
* \param predict_func The pointer to the prediction function defined in python
* \param predict_stage_func The pointer to the prediction function defined in python
*/
PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, PythonBasedModelNode);
};

} // namespace auto_scheduler
} // namespace tvm

#endif // TVM_AUTO_SCHEDULER_COST_MODEL_H_
3 changes: 2 additions & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from . import workload_registry

# Shortcut
from .compute_dag import ComputeDAG
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
auto_schedule, EmptyPolicy
from .compute_dag import ComputeDAG
from .cost_model import RandomModel
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \
LocalRPCMeasureContext
from .measure_record import RecordToFile, RecordReader, load_best, \
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/auto_scheduler/cost_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin
""" Cost model that estimates the performance of programs """

from .cost_model import RandomModel
142 changes: 142 additions & 0 deletions python/tvm/auto_scheduler/cost_model/cost_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Cost model that estimates the performance of programs """
import ctypes
import numpy as np

import tvm._ffi
from tvm.runtime import Object
from .. import _ffi_api


@tvm._ffi.register_object("auto_scheduler.CostModel")
class CostModel(Object):
"""The base class for cost model"""

@tvm._ffi.register_object("auto_scheduler.RandomModel")
class RandomModel(CostModel):
"""A model returns random estimation for all inputs"""
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.RandomModel)

def update(self, inputs, results):
"""Update the cost model according to new measurement results (training data).
Parameters
----------
inputs : List[MeasureInput]
The measurement inputs
results : List[MeasureResult]
The measurement results
"""
_ffi_api.CostModelUpdate(self, inputs, results)

def predict(self, search_task, states):
"""Predict the scores of states
Parameters
----------
search_task : SearchTask
The search task of states
statse : List[State]
The input states
Returns
-------
scores: List[float]
The predicted scores for all states
"""
return [x.value for x in _ffi_api.CostModelPredict(self, search_task, states)]


@tvm._ffi.register_func("auto_scheduler.cost_model.random_number")
def random_number(n, return_ptr):
""" A random number generator func for c++'s RandomModel """
if n == 0:
return
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,))
array_wrapper[:] = np.random.uniform(0, 1, (n,))


@tvm._ffi.register_object("auto_scheduler.PythonBasedModel")
class PythonBasedModel(CostModel):
"""Base class for cost models implemented in python"""
def __init__(self):
def update_func(inputs, results):
self.update(inputs, results)

def predict_func(task, states, return_ptr):
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(len(states),))
array_wrapper[:] = self.predict(task, states)

def predict_stage_func(task, states, return_ptr):
ret = self.predict_stages(task, states)
return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape)
array_wrapper[:] = ret

self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func,
predict_func, predict_stage_func)

def update(self, inputs, results):
"""Update the cost model according to new measurement results (training data).
Parameters
----------
inputs : List[MeasureInput]
The measurement inputs
results : List[MeasureResult]
The measurement results
"""
raise NotImplementedError

def predict(self, task, states):
"""Predict the scores of states
Parameters
----------
search_task : SearchTask
The search task of states
statse : List[State]
The input states
Returns
-------
scores: List[float]
The predicted scores for all states
"""
raise NotImplementedError

def predict_stages(self, task, states):
"""Predict the scores of states
Parameters
----------
search_task : SearchTask
The search task of states
statse : List[State]
The input states
Returns
-------
scores: List[float]
The predicted scores for all stages in all states in the packed format
"""
raise NotImplementedError
Loading

0 comments on commit 5a49ec7

Please sign in to comment.