diff --git a/include/tvm/auto_scheduler/feature.h b/include/tvm/auto_scheduler/feature.h
index 504c2b8ca5a0..cce45356b2a7 100644
--- a/include/tvm/auto_scheduler/feature.h
+++ b/include/tvm/auto_scheduler/feature.h
@@ -21,7 +21,7 @@
  * \file auto_scheduler/feature.h
  * \brief Feature extraction for the cost model.
  * We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
- * so we call this feature as "Per Store" feature.
+ * so we call this feature as "per-store" feature.
  * The cost model also does prediction for each BufferStoreNode statement and aggregates
  * the predictions as the whole score for a TVM IR (Stmt).
  *
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index bf32cec675f4..8d262b941b3b 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -29,8 +29,8 @@
 from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
     auto_schedule, EmptyPolicy, SketchPolicy
 from .compute_dag import ComputeDAG
-from .cost_model import RandomModel
-from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \
+from .cost_model import RandomModel, XGBModel
+from .measure import MeasureInput, MeasureResult, LocalBuilder, LocalRunner, RPCRunner, \
     LocalRPCMeasureContext
 from .measure_record import RecordToFile, RecordReader, load_best, \
     load_records, save_records
diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py
index d3b18fe020fb..eb5a3fb49934 100644
--- a/python/tvm/auto_scheduler/auto_schedule.py
+++ b/python/tvm/auto_scheduler/auto_schedule.py
@@ -161,7 +161,9 @@ def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=No
             seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)
 
     def generate_sketches(self, print_for_debug=False):
-        """ Generate the sketches, this is mainly used for debug.
+        """ Generate the sketches.
+        This python interface is mainly used for debugging and testing.
+        The actual search is all done in c++.
 
         Parameters
         ----------
@@ -180,6 +182,24 @@ def generate_sketches(self, print_for_debug=False):
             print(s)
         return sketches
 
+    def sample_initial_population(self, pop_size):
+        """Sample initial population.
+        This python interface is mainly used for debugging and testing.
+        The actual search is all done in c++.
+
+        Parameters
+        ----------
+        pop_size : int
+            The size of sampled population
+
+        Returns
+        -------
+        states: List[State]
+            The sampled states
+        """
+        states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
+        return states
+
 
 @tvm._ffi.register_object("auto_scheduler.TuningOptions")
 class TuningOptions(Object):
     """ This controls the options of performance tuning.
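The two debugging entry points above (`generate_sketches` and the new `sample_initial_population`) can be exercised directly from Python. The snippet below is an illustrative sketch only, not part of the patch; it mirrors the unit test added at the end of this diff, and `matmul_auto_scheduler_test` is the workload helper assumed to come from the test utilities (test_auto_scheduler_common.py).

    import tvm
    from tvm import auto_scheduler
    # Assumption: reuse the matmul workload helper from the unit tests in this patch.
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    N = 128
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create('llvm'))

    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    sketches = policy.generate_sketches()           # structural sketches, generated in c++
    states = policy.sample_initial_population(50)   # random initial population derived from the sketches
    print(len(sketches), len(states))

Both wrappers only forward to the C++ implementation through the FFI, so they are cheap to call from tests.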
diff --git a/python/tvm/auto_scheduler/cost_model/__init__.py b/python/tvm/auto_scheduler/cost_model/__init__.py
index fc3821cf7998..56e4a5f9128b 100644
--- a/python/tvm/auto_scheduler/cost_model/__init__.py
+++ b/python/tvm/auto_scheduler/cost_model/__init__.py
@@ -18,3 +18,4 @@
 """ Cost model that estimates the performance of programs """
 
 from .cost_model import RandomModel
+from .xgb_model import XGBModel
diff --git a/python/tvm/auto_scheduler/cost_model/cost_model.py b/python/tvm/auto_scheduler/cost_model/cost_model.py
index 699719df6761..80e963f18387 100644
--- a/python/tvm/auto_scheduler/cost_model/cost_model.py
+++ b/python/tvm/auto_scheduler/cost_model/cost_model.py
@@ -146,5 +146,25 @@ def predict_stages(self, task, states):
         -------
         scores: List[float]
             The predicted scores for all stages in all states in the packed format
+
+        Note
+        ----
+        For faster data copy between c++ and python, the python part returns scores in a
+        single flattened array using a packed format. The c++ part then unpacks the flattened array.
+
+        The packed format is:
+        {
+          float scores[N];                  // scores[i] is the score for states[i].
+          int   n_stage_0;                  // the number of stages in states[0]
+          float stage_scores_0[n_stage_0];  // the scores for all stages in states[0]
+          int   n_stage_1;                  // the number of stages in states[1]
+          float stage_scores_1[n_stage_1];  // the scores for all stages in states[1]
+          ...
+          int   n_stage_i;                  // the number of stages in states[i]
+          float stage_scores_i[n_stage_i];  // the scores for all stages in states[i]
+          ...                               // until i == N - 1
+        }
+        To implement this format, we also store int as float, so we can store all numbers
+        into a single float array.
         """
         raise NotImplementedError
diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
new file mode 100644
index 000000000000..6fd8d17259fd
--- /dev/null
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -0,0 +1,599 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
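The packed per-stage score format documented in the `predict_stages` note above (cost_model.py) is easiest to see in a small, self-contained sketch. This is illustrative only and not part of the patch; `pack_stage_scores` and `unpack_stage_scores` are hypothetical helper names, only the layout follows the docstring.

    import numpy as np

    def pack_stage_scores(state_scores, stage_scores):
        # Layout: scores[N], then for each state i: n_stage_i followed by its stage scores.
        # Integers are stored as floats so everything fits into one float array.
        packed = [float(s) for s in state_scores]
        for scores in stage_scores:
            packed.append(float(len(scores)))
            packed.extend(float(s) for s in scores)
        return np.array(packed, dtype=np.float32)

    def unpack_stage_scores(packed, n_states):
        # Inverse of the layout above: read scores[N], then (n_stage_i, stage scores...) per state.
        state_scores = packed[:n_states]
        per_stage = []
        pos = n_states
        for _ in range(n_states):
            n_stage = int(packed[pos])
            per_stage.append(packed[pos + 1:pos + 1 + n_stage])
            pos += 1 + n_stage
        return state_scores, per_stage

    # Example: two states with 2 and 3 stages respectively.
    packed = pack_stage_scores([0.9, 0.4], [[0.5, 0.4], [0.1, 0.2, 0.1]])
    print(unpack_stage_scores(packed, 2))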
+# pylint: disable=invalid-name
+
+"""Cost model based on xgboost"""
+import multiprocessing
+import logging
+from collections import defaultdict
+
+import numpy as np
+import xgboost as xgb
+from xgboost.core import EarlyStopException
+from xgboost.callback import _fmt_metric
+from xgboost.training import aggcv
+
+from tvm.autotvm.tuner.metric import max_curve
+from .cost_model import PythonBasedModel
+from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
+from ..measure_record import RecordReader
+
+logger = logging.getLogger('auto_scheduler')
+
+class XGBDMatrixContext:
+    """A global context to hold additional attributes of xgb.DMatrix"""
+    def __init__(self):
+        self.context_dict = defaultdict(dict)
+
+    def get(self, key, matrix, default=None):
+        """
+        Get an attribute of a xgb.DMatrix
+
+        Parameters
+        ----------
+        key: str
+            The name of the attribute
+        matrix: xgb.DMatrix
+            The matrix
+        default: Optional[Any]
+            The default value if the item does not exist
+        """
+        return self.context_dict[key].get(matrix.handle.value, default)
+
+    def set(self, key, matrix, value):
+        """
+        Set an attribute for a xgb.DMatrix
+
+        Parameters
+        ----------
+        key: str
+            The name of the attribute
+        matrix: xgb.DMatrix
+            The matrix
+        value: Optional[Any]
+            The new value
+        """
+        self.context_dict[key][matrix.handle.value] = value
+
+dmatrix_context = XGBDMatrixContext()
+
+
+class XGBModel(PythonBasedModel):
+    """Train an XGBoost model to predict the normalized throughputs of programs.
+
+    Let the normalized throughput be the score of a program (higher is better). We predict
+    the (approximate) score of a program = the sum of the scores of all stages in this program.
+    i.e. score(P) = score_s0 + score_s1 + ... + score_sn,
+    where score_si is the score of Stage i in Program P.
+
+    We extract a feature vector for each stage and let XGBoost predict the score for each stage.
+    We then sum up the predictions as the score of the whole program.
+
+    We use RMSE as the loss function. i.e. loss(P, y) = 1/2 * (score(P) - y)^2,
+    where P is the program and y is the normalized throughput according to
+    the ground truth (measurement).
+    XGBoost does not support this loss function because `score(P)` is a sum of the predictions
+    of several samples, so we implemented a custom loss function and call it pack-sum-rmse.
+    It is called "pack-sum" because we combine several samples into a "pack" and sum up
+    their predictions.
+    """
+    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
+        self.xgb_params = {
+            'max_depth': 10,
+            'gamma': 0.001,
+            'min_child_weight': 0,
+            'eta': 0.2,
+            # todo(merrymercy): automatically decrease learning rate when the loss is too large
+
+            'n_gpus': 0,
+            'nthread': multiprocessing.cpu_count() // 2,
+            'verbosity': 0,
+            'seed': seed or 43,
+            'disable_default_eval_metric': 1
+        }
+        self.bst = None
+        self.plan_size = 32
+        self.num_warmup_sample = num_warmup_sample
+        self.verbose_eval = verbose_eval
+
+        super().__init__()
+
+        # cache measurement input/result pairs and extracted features
+        self.inputs = []
+        self.results = []
+        self.inputs_feature_cache = []
+
+    def update(self, inputs, results):
+        """Update the cost model according to new measurement results (training data).
+        XGBoost does not support incremental training, so we re-train a new model every time.
+
+        Parameters
+        ----------
+        inputs : List[MeasureInput]
+            The measurement inputs
+        results : List[MeasureResult]
+            The measurement results
+        """
+        if len(inputs) <= 0:
+            return
+        assert len(inputs) == len(results)
+
+        self.inputs.extend(inputs)
+        self.results.extend(results)
+
+        # extract feature
+        n_cached = len(self.inputs_feature_cache)
+        features, normalized_throughputs, task_ids = \
+            get_per_store_features_from_measure_pairs(self.inputs, self.results,
+                                                      skip_first_n_feature_extraction=n_cached)
+        if n_cached > 0:
+            features = list(features)
+            features[:n_cached] = self.inputs_feature_cache
+            features = np.array(features, dtype=object)
+        self.inputs_feature_cache = features
+        dtrain = pack_sum_xgbmatrix(features, normalized_throughputs,
+                                    task_ids, normalized_throughputs)
+
+        # train xgb model
+        self.bst = xgb.train(self.xgb_params, dtrain,
+                             num_boost_round=10000,
+                             obj=pack_sum_square_error,
+                             callbacks=[custom_callback(
+                                 stopping_rounds=50,
+                                 metric='tr-p-rmse',
+                                 fevals=[
+                                     pack_sum_rmse, pack_sum_average_peak_score(self.plan_size),
+                                 ],
+                                 evals=[(dtrain, 'tr')],
+                                 maximize=False,
+                                 verbose_eval=self.verbose_eval)])
+
+    def predict(self, task, states):
+        """Predict the scores of states
+
+        Parameters
+        ----------
+        search_task : SearchTask
+            The search task of states
+        states : List[State]
+            The input states
+
+        Returns
+        -------
+        scores: List[float]
+            The predicted scores for all states
+        """
+        features = get_per_store_features_from_states(states, task)
+        if self.bst is not None and len(self.inputs) > self.num_warmup_sample:
+            dtest, pack_ids = feature_to_pack_sum_xgbmatrix(features)
+            raw_preds = self.bst.predict(dtest)
+            ret = predict_throughput_pack_sum(raw_preds, pack_ids)
+        else:
+            ret = np.random.uniform(0, 1, (len(states),))
+
+        # Predict -inf for invalid states that failed to be lowered.
+        for idx, feature in enumerate(features):
+            if feature.min() == feature.max() == 0:
+                ret[idx] = float('-inf')
+
+        return ret
+
+    def predict_stages(self, task, states):
+        """Predict the scores of all stages in states. This is the breakdown version of `predict`.
+
+        Parameters
+        ----------
+        search_task : SearchTask
+            The search task of states
+        states : List[State]
+            The input states
+
+        Returns
+        -------
+        scores: List[float]
+            The predicted scores for all stages in all states in the packed format
+
+        Note
+        ----
+        For faster data copy between c++ and python, the python part returns scores in a
+        single flattened array using a packed format. The c++ part then unpacks the flattened array.
+
+        The packed format is:
+        {
+          float scores[N];                  // scores[i] is the score for states[i].
+          int   n_stage_0;                  // the number of stages in states[0]
+          float stage_scores_0[n_stage_0];  // the scores for all stages in states[0]
+          int   n_stage_1;                  // the number of stages in states[1]
+          float stage_scores_1[n_stage_1];  // the scores for all stages in states[1]
+          ...
+          int   n_stage_i;                  // the number of stages in states[i]
+          float stage_scores_i[n_stage_i];  // the scores for all stages in states[i]
+          ...                               // until i == N - 1
+        }
+        To implement this format, we also store int as float, so we can store all numbers
+        into a single float array.
+ """ + features = get_per_store_features_from_states(states, task) + if self.bst is not None and len(self.inputs) > self.num_warmup_sample: + dtest, pack_ids = feature_to_pack_sum_xgbmatrix(features) + raw_preds = self.bst.predict(dtest) + breakdown = predict_throughput_pack_sum(raw_preds, pack_ids) + stage_scores = [[] for _ in range(len(states))] + for pred, pack_id in zip(raw_preds, pack_ids): + stage_scores[pack_id].append(pred) + for idx, stage_score in enumerate(stage_scores): + breakdown = np.append(breakdown, len(stage_score)) + breakdown = np.concatenate((breakdown, np.array(stage_score))) + else: + breakdown = np.concatenate( + (np.random.uniform(0, 1, (len(states), )), np.zeros(len(states), ))) + + # Predict 0 for invalid states that failed to be lowered. + for idx, feature in enumerate(features): + if feature.min() == feature.max() == 0: + breakdown[idx] = float('-inf') + + return breakdown + + def update_from_file(self, file_name, n_lines=None): + """Load measure records from a log file to update the cost model. + This function can be used to pre-train the cost model with history log files. + + Parameters + ---------- + file_name: str + The filename + n_lines: Optional[int] + Only load first n lines of the log file + """ + inputs, results = RecordReader(file_name).read_lines(n_lines) + logger.info("XGBModel: Loaded %s measurement records from %s", len(inputs), file_name) + self.update(inputs, results) + + def save(self, file_name: str): + """Save the model to a file + + Parameters + ---------- + file_name: str + The filename + """ + self.bst.save_model(file_name) + + def load(self, file_name: str): + """Load the model from a file + + Parameters + ---------- + file_name: str + The filename + """ + if self.bst is None: + self.bst = xgb.Booster(self.xgb_params) + self.bst.load_model(file_name) + self.num_warmup_sample = -1 + + +def feature_to_pack_sum_xgbmatrix(xs): + """Convert an extracted multi-stage feature vector to a xgbmatrx in pack-sum format + + Parameters + ---------- + xs: np.ndarray + The feature vector + + Returns + ------- + dmatrix: xgb.DMatrix + The DMatrix + pack_ids: List[int] + pack ids information + """ + x_flatten = [] + pack_ids = [] + + for ct, x in enumerate(xs): + for row in x: + x_flatten.append(row) + pack_ids.append(ct) + + return xgb.DMatrix(np.array(x_flatten)), pack_ids + + +def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): + """Convert (feature, label) pairs into a xgb matrix with pack-sum format + + Parameters + ---------- + xs: np.ndarray + The feature vector + ys: np.ndarray + The normaizlied throughput + gids: Optional[List[int]] + Group id (task id) + weights: Optional[np.ndarray] + The weight of samples + + Returns + ------- + dmatrix: xgb.DMatrix + The DMatrix with pack-sum information + """ + if gids is not None: + # sort by group + indices = gids.argsort() + xs, ys = xs[indices], ys[indices] + group_sizes = np.bincount(gids) + if weights is not None: + weights = weights[indices] + else: + # assume it has only one group + group_sizes = [len(xs)] + + x_flatten = [] + y_flatten = [] + weights_flatten = [] + pack_ids = [] + + if weights is not None: + for ct, (x, y, w) in enumerate(zip(xs, ys, weights)): + for row in x: + x_flatten.append(row) + y_flatten.append(y) + weights_flatten.append(w) + pack_ids.append(ct) + else: + for ct, (x, y) in enumerate(zip(xs, ys)): + for row in x: + x_flatten.append(row) + y_flatten.append(y) + pack_ids.append(ct) + + ret = xgb.DMatrix(np.array(x_flatten), y_flatten) + if weights is not None: + 
ret.set_weight(weights_flatten) + dmatrix_context.set('pack_ids', ret, np.array(pack_ids)) + dmatrix_context.set('group_sizes', ret, group_sizes) + return ret + + +def predict_throughput_pack_sum(raw_preds, pack_ids): + """Predict the throughputs for predictions in pack-sum format + + Parameters + ---------- + raw_preds: np.ndarray + The raw predictions + pack_ids: List[int] + The pack id for predictions + + Returns + ------- + throughputs: np.ndarray + The throughput + """ + sum_pred = np.bincount(pack_ids, weights=raw_preds) + return sum_pred + +def pack_sum_square_error(preds, dtrain): + """Implement square error loss on pack-sum format as + a custom objective function for xgboost. + + Parameters + ---------- + preds: np.ndarray + The predicitons + dtrain: xgb.DMatrix + The training set + + Returns + ------- + gradient: np.ndarray + hessian: np.ndarray + gradient and hessian according to the xgboost format + """ + pack_ids = dmatrix_context.get("pack_ids", dtrain) + weight = dtrain.get_weight() + + sum_pred = np.bincount(pack_ids, weights=preds) + x = sum_pred[pack_ids] + y = dtrain.get_label() + gradient = x - y + hessian = np.ones_like(gradient) + + if len(weight) == 0: + return gradient, hessian + + return gradient * weight, hessian * weight + +def pack_sum_rmse(raw_preds, labels): + """Evaluate RMSE (rooted mean square error) in the pack-sum format + + Parameters + ---------- + raw_preds: np.ndarray + The raw prediction + labels: xgb.DMatrix + The groud-truth label matrix + + Returns + ------- + name: str + score: float + The name and score of this metric + """ + pack_ids = dmatrix_context.get("pack_ids", labels) + preds = predict_throughput_pack_sum(raw_preds, pack_ids)[pack_ids] + return 'p-rmse', np.sqrt(np.mean(np.square((preds - labels.get_label())))) + +def pack_sum_average_peak_score(N): + """Return the evaluation function for average-peak-score@N + + Parameters + ---------- + N: int + The "N" in "average-peak-score@N" + + Returns + ------- + The evaluation function + """ + + def feval(preds, labels): + """Evaluate average-peak-score@N in the pack-sum format + + Parameters + ---------- + raw_preds: np.ndarray + The raw prediction + labels: xgb.DMatrix + The groud-truth label matrix + + Returns + ------- + name: str + score: float + The name and score of this metric + """ + group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) + pack_ids = dmatrix_context.get("pack_ids", labels) + + preds = predict_throughput_pack_sum(preds, pack_ids) + labels = (np.bincount(pack_ids, weights=labels.get_label()) + / np.unique(pack_ids, return_counts=True)[1]) + + scores = [] + offset = 0 + for size in group_sizes: + preds_group = preds[offset:offset + size] + labels_group = labels[offset:offset + size] + offset += size + + trials = np.argsort(preds_group)[::-1][:N] + trial_scores = labels_group[trials] + curve = max_curve(trial_scores) / np.max(labels_group) + scores.append(np.mean(curve)) + return "a-peak@%d" % N, np.mean(scores) + return feval + + +def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None, + maximize=False, verbose_eval=True, skip_every=2): + """Callback function for xgboost to support multiple custom evaluation functions""" + state = {} + metric_shortname = metric.split("-")[1] + + def init(env): + """internal function""" + bst = env.model + + state['maximize_score'] = maximize + state['best_iteration'] = 0 + if maximize: + state['best_score'] = float('-inf') + else: + state['best_score'] = float('inf') + + if bst is not None: + if 
bst.attr('best_score') is not None: + state['best_score'] = float(bst.attr('best_score')) + state['best_iteration'] = int(bst.attr('best_iteration')) + state['best_msg'] = bst.attr('best_msg') + else: + bst.set_attr(best_iteration=str(state['best_iteration'])) + bst.set_attr(best_score=str(state['best_score'])) + else: + assert env.cvfolds is not None + + def callback(env): + """internal function""" + if not state: + init(env) + + bst = env.model + i = env.iteration + cvfolds = env.cvfolds + + res_dict = {} + + if i % skip_every == 1: + return + + ##### evaluation ##### + if cvfolds is not None: + for feval in fevals: + tmp = aggcv([f.eval(i, feval) for f in cvfolds]) + for k, mean, std in tmp: + res_dict[k] = [mean, std] + else: + for feval in fevals: + bst_eval = bst.eval_set(evals, i, feval) + res = [x.split(':') for x in bst_eval.split()] + for kv in res[1:]: + res_dict[kv[0]] = [float(kv[1])] + + eval_res = [] + keys = list(res_dict.keys()) + keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) + for key in keys: + v = res_dict[key] + eval_res.append([key] + v) + + ##### print eval result ##### + if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0: + infos = ["XGB iter: %3d" % i] + for item in eval_res: + if 'null' in item[0]: + continue + infos.append("%s: %.6f" % (item[0], item[1])) + + logger.debug("\t".join(infos)) + if log_file: + with open(log_file, "a") as fout: + fout.write("\t".join(infos) + '\n') + + ##### choose score and do early stopping ##### + score = None + for item in eval_res: + if item[0] == metric: + score = item[1] + break + assert score is not None + + best_score = state['best_score'] + best_iteration = state['best_iteration'] + maximize_score = state['maximize_score'] + if (maximize_score and score > best_score) or \ + (not maximize_score and score < best_score): + msg = '[%d] %s' % ( + env.iteration, + '\t'.join([_fmt_metric(x) for x in eval_res])) + state['best_msg'] = msg + state['best_score'] = score + state['best_iteration'] = env.iteration + # save the property to attributes, so they will occur in checkpoint. + if env.model is not None: + env.model.set_attr(best_score=str(state['best_score']), + best_iteration=str(state['best_iteration']), + best_msg=state['best_msg']) + elif env.iteration - best_iteration >= stopping_rounds: + best_msg = state['best_msg'] + if verbose_eval and env.rank == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + raise EarlyStopException(best_iteration) + + return callback diff --git a/python/tvm/auto_scheduler/feature.py b/python/tvm/auto_scheduler/feature.py index 3ed87f3c2b9a..e531c3d46214 100644 --- a/python/tvm/auto_scheduler/feature.py +++ b/python/tvm/auto_scheduler/feature.py @@ -19,7 +19,7 @@ Python API for Feature extraction. The extracted features vector are used by cost models. We extract one feature vector per BufferStoreNode statement in a TIR Stmt, -so we call this feature as "Per Store" feature. +so we call this feature as "per-store" feature. The cost model also does prediction for each BufferStoreNode statement and aggregates the predicted score of each BufferStoreNode as the score of a TIR Stmt. 
@@ -61,22 +61,30 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar
         Normalized throughputs
     task_ids: np.ndarray
         Task ids
-    """
-    # The format for n records is:
-    # {
-    # int n;
-    # int[n+2] sizes
+
+    Note
+    ----
+    For faster data copy between c++ and python, the c++ part returns features in a single
+    flattened array using a packed format. The python part then unpacks the flattened array.
+
+    The packed format for n records is:
+    {
+      int   n;
+      int   sizes[n+2];                  // The sizes for the following arrays
-    # float[sizes[0]] feature for record 1
-    # float[sizes[1]] feature for record 2
-    # ... feature for record i...
-    # float[sizes[n-1]] feature for record n
+      float features_0[sizes[0]];        // The features for record 0
+      float features_1[sizes[1]];        // The features for record 1
+      ...
+      float features_i[sizes[i]];        // The features for record i
+      ...                                // until i == n - 1
-    # float[sizes[n]] normalized throughput for n records
-    # int[sizes[n+1]] task id for n records
-    # }
+      float throughputs[sizes[n]];       // The normalized throughputs for n records
+      int   task_ids[sizes[n+1]];        // The task ids for n records
+    }
+    To implement this format, we also store int as float, so we can store all numbers
+    into a single float array.
+    """
     vec_len = DEFAULT_FEATURE_VEC_LEN

     # unpack sizes
@@ -95,8 +103,8 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar
     # Now, we need to unpack the feature for multiple statements.
     # The format is:
     # {
-    #     int n_stmts
-    #     float[n_stmt][vec_len] feature_vecs
+    #   int   n_stage;                        // The number of stages
+    #   float feature_vecs[n_stage][vec_len]  // The feature vector for each stage
     # }
     # where vec_len can be calculated by `(size - 1) / n_stmts`
@@ -137,7 +145,7 @@ def get_per_store_features_from_file(filename: str,
                                      max_lines: int,
                                      max_n_bufs: Optional[int] = None) \
         -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Get per_store features from a log file
+    """Get per-store features from a log file

     Parameters
     ----------
@@ -167,7 +175,7 @@ def get_per_store_features_from_measure_pairs(inputs: List[MeasureInput],
                                               skip_first_n_feature_extraction: int = 0,
                                               max_n_bufs: Optional[int] = None) \
         -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Get per_store features from measurement input/result pairs
+    """Get per-store features from measurement input/result pairs

     Parameters
     ----------
@@ -197,7 +205,7 @@ def get_per_store_features_from_measure_pairs(inputs: List[MeasureInput],
 def get_per_store_features_from_states(states: List[Union[State, StateObject]],
                                        task: "SearchTask",
                                        max_n_bufs: Optional[int] = None) -> List[np.ndarray]:
-    """Get per_store features from measurement input/result pairs
+    """Get per-store features from states

     Parameters
     ----------
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 4d626f799462..925de2f871e6 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -48,6 +48,7 @@
 from tvm.contrib import tar, ndk

 from . import _ffi_api
+from .loop_state import StateObject
 from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \
     check_remote

@@ -71,12 +72,13 @@ class MeasureInput(Object):
     Parameters
     ----------
     task : SearchTask
-        The SearchTask of this measure.
-    state : State
+        The SearchTask of this measurement.
+    state : Union[State, StateObject]
         The State to be measured.
""" def __init__(self, task, state): - self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) + state = state if isinstance(state, StateObject) else state.state_object + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state) @tvm._ffi.register_object("auto_scheduler.BuildResult") diff --git a/src/auto_scheduler/cost_model.cc b/src/auto_scheduler/cost_model.cc index 68c1d5c1f118..456e2ef3cc51 100644 --- a/src/auto_scheduler/cost_model.cc +++ b/src/auto_scheduler/cost_model.cc @@ -78,6 +78,25 @@ void PythonBasedModelNode::PredictStages(const SearchTask& task, const Array(flatten_scores.data())); + /* For faster data copy between c++ and python, the python part returns scores in a + * single flatten array using a packed format. The c++ part then unpacks the flatten array. + * + * The packed format is: + * { + * float scores[N]; // scores[i] is the score for states[i]. + * int n_stage_0; // the number of stages in states[0] + * float stage_scores_0[[n_stage_0] // the scores for all stages in states[0] + * int n_stage_1; // the number of stages in states[1] + * float stage_scores_1[n_stage_1]; // the scores for all stages in states[1] + * ... + * int n_stage_i; // the number of stages in states[i] + * float stage_scores_1[n_stage_i]; // the scores for all stages in states[i] + * ... // untill i == N - 1 + * } + * To implement this format, we also store int as float, so we can store all numbers + * into a single float array. + */ + // Unpack flatten scores. state_scores->clear(); stage_scores->clear(); diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 2f89750919ba..bbef387d3f72 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1496,21 +1496,27 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, /* * \brief Serialize a two-dimensional variable-size feature vector with normalized throughputs * and task ids to a one-dimensional flatten byte array. - * We have to serialize it for faster transmission speed when copying it to python. - * This flatten array will be deserialized in python. * - * serialization format for n records: + * For faster data copy between c++ and python, the c++ part returns features in a single + * flatten array using a packed format. The python part then unpacks the flatten array. * - * int n; - * int[n+2] sizes + * The packed format for n records is: + * { + * int n; + * int sizes[n+2]; // The sizes for the following arrays * - * float[sizes[0]] feature for record 1 - * float[sizes[1]] feature for record 2 - * ... feature for record i... - * float[sizes[n-1]] feature for record n + * float features_0[size[0]]; // The features for record 0 + * float features_1[size[1]]; // The features for record 1 + * ... + * float features_i[size[i]]; // The features for record i + * ... // until i == n - 1 * - * float[sizes[n]] normalized throughput for n records - * int[sizes[n+1]] task id for n records + * float throuputs[sizes[n]]; // The normalized throughputs for n records + * int task_ids[size[n+1]; // The task ids for n records + * + * } + * To implement this format, we also store int as float, so we can store all numbers + * into a single float array. 
*/ TVMByteArray SerializeFeatures(std::vector>&& features, std::vector&& normalized_throughputs, diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 450c429f95c8..c428cf71af17 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -196,7 +196,7 @@ Array SketchPolicyNode::SearchOneRound(int num_random_states, Array& sketches = GenerateSketches(); // 2. Sample the init population - Array init_populations = SampleInitPopulation( + Array init_population = SampleInitPopulation( sketches, is_cost_model_reasonable ? population - num_use_measured : population); // 3. If the cost model is useless (i.e. RandomCostModel), just random pick some generated @@ -205,13 +205,13 @@ Array SketchPolicyNode::SearchOneRound(int num_random_states, Array indices = Argsort(measured_states_throughputs_); for (int i = 0; i < num_use_measured; i++) { - init_populations.push_back(measured_states_vector_[indices[i]]); + init_population.push_back(measured_states_vector_[indices[i]]); } // Sample some random states for eps-greedy - *random_states = RandomSampleStates(init_populations, &rand_gen, num_random_states * 10); - return EvolutionarySearch(init_populations, num_measure_per_iter_ * 2); + *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 10); + return EvolutionarySearch(init_population, num_measure_per_iter_ * 2); } else { - return RandomSampleStates(init_populations, &rand_gen, num_measure_per_iter_ * 3); + return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3); } } @@ -322,7 +322,7 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches return out_states; } -Array SketchPolicyNode::EvolutionarySearch(const Array& init_populations, +Array SketchPolicyNode::EvolutionarySearch(const Array& init_population, int out_size) { Array best_states; auto tic_begin = std::chrono::high_resolution_clock::now(); @@ -397,5 +397,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy") TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches") .set_body_typed([](SketchPolicy policy) { return policy->GenerateSketches(); }); +TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation") + .set_body_typed([](SketchPolicy policy, int pop_size) { + const Array& sketches = policy->GenerateSketches(); + + Array init_population = policy->SampleInitPopulation(sketches, pop_size); + return init_population; + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 288de839a1b7..104e51e8ab25 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -83,7 +83,7 @@ class SketchPolicyNode : public SearchPolicyNode { public: /*! \brief The cost model to estimate the complete schedules. */ CostModel schedule_cost_model; - /*! \brief The parameters map for this search process. */ + /*! \brief The parameters map for this search policy. */ Map params; /*! \brief The rules to generate sketches. */ std::vector sketch_rules; @@ -103,6 +103,14 @@ class SketchPolicyNode : public SearchPolicyNode { */ Array GenerateSketches(); + /*! + * \brief Sample the init population. + * \param sketches The initial sketches for the sampled population + * \param out_size The number of output states. + * \return The generated states (the initial population). 
+ */ + Array SampleInitPopulation(const Array& sketches, int out_size); + static constexpr const char* _type_key = "auto_scheduler.SketchPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode); @@ -117,14 +125,6 @@ class SketchPolicyNode : public SearchPolicyNode { */ Array SearchOneRound(int num_random_states, Array* random_states = nullptr); - /*! - * \brief Sample init population. - * \param sketches The initial sketches to process population. - * \param out_size The number of expected output states. - * \return The generated states after initial population. - */ - Array SampleInitPopulation(const Array& sketches, int out_size); - /*! * \brief Perform evolutionary search. * \param init_populations The states generated from init population. diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py index 3d152a4f704d..20449035517b 100644 --- a/tests/python/unittest/test_auto_scheduler_cost_model.py +++ b/tests/python/unittest/test_auto_scheduler_cost_model.py @@ -17,26 +17,65 @@ """Test cost models""" +import tempfile + +import numpy as np + import tvm from tvm import auto_scheduler from test_auto_scheduler_common import matmul_auto_scheduler_test -def test_random_model(): - if not tvm.runtime.enabled("llvm"): - return +def get_sample_records(number): + """Generate random a list of random MeasureInput and MeasureResult pairs""" N = 128 workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N)) dag = auto_scheduler.ComputeDAG(workload_key) target = tvm.target.create('llvm') task = auto_scheduler.SearchTask(dag, workload_key, target) + policy = auto_scheduler.SketchPolicy(task, verbose=0) + states = policy.sample_initial_population(number) + + inputs = [auto_scheduler.MeasureInput(task, s) for s in states] + results = [auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0) + for _ in range(len(inputs))] + + return task, dag, inputs, results + + +def test_random_model(): + task, dag, inputs, results = get_sample_records(50) model = auto_scheduler.RandomModel() - model.update([], []) - scores = model.predict(task, [dag.init_state, dag.init_state]) - assert len(scores) == 2 + model.update(inputs, results) + scores = model.predict(task, [x.state for x in inputs]) + assert len(scores) == len(inputs) + + +def test_xgb_model(): + task, dag, inputs, results = get_sample_records(50) + + model = auto_scheduler.XGBModel(num_warmup_sample=-1) + model.update(inputs, results) + preds = model.predict(task, [x.state for x in inputs]) + assert len(preds) == len(inputs) + + costs = [np.mean([x.value for x in res.costs]) for res in results] + throughputs = np.min(costs) / costs + + rmse = np.sqrt(np.mean([np.square(pred - label) for pred, label in zip(preds, throughputs)])) + assert rmse <= 0.05 + + with tempfile.NamedTemporaryFile() as fp: + auto_scheduler.save_records(fp.name, inputs, results) + model.update_from_file(fp.name) + + with tempfile.NamedTemporaryFile() as fp: + model.save(fp.name) + model.load(fp.name) if __name__ == "__main__": test_random_model() + test_xgb_model()
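Taken together, the pieces in this patch are meant to be used roughly as follows. This is an illustrative sketch only, not part of the patch: the task setup mirrors get_sample_records() in the test above, `matmul_auto_scheduler_test` again comes from the test helpers, and 'matmul_records.json' is a hypothetical log file produced by earlier tuning runs.

    import tvm
    from tvm import auto_scheduler
    from test_auto_scheduler_common import matmul_auto_scheduler_test  # assumed test helper

    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (128, 128, 128))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create('llvm'))

    # The new cost model; it falls back to random predictions until it has seen
    # more than num_warmup_sample measured samples.
    model = auto_scheduler.XGBModel(num_warmup_sample=100)
    model.update_from_file('matmul_records.json')   # optional pre-training from history records

    # Plug the model into the sketch search policy and inspect its predictions.
    policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=model, verbose=0)
    states = policy.sample_initial_population(10)
    scores = model.predict(task, states)            # one score per state, higher is better

    model.save('xgb.model')                         # the booster can be saved and reloaded
    model.load('xgb.model')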