From 56b01870f6c482c44a81d7a656757919f8989dd7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 Jul 2020 22:25:01 -0700 Subject: [PATCH] move header files and polish comments --- .../tvm}/auto_scheduler/auto_schedule.h | 28 +++++----- .../tvm}/auto_scheduler/compute_dag.h | 55 ++++++++++--------- .../tvm}/auto_scheduler/loop_state.h | 36 ++++++------ {src => include/tvm}/auto_scheduler/measure.h | 34 ++++++------ .../tvm}/auto_scheduler/measure_record.h | 28 +++++----- .../tvm/auto_scheduler}/search_policy.h | 37 ++++++------- .../tvm}/auto_scheduler/search_task.h | 5 +- .../tvm}/auto_scheduler/transform_step.h | 13 ++--- python/tvm/auto_scheduler/auto_schedule.py | 5 +- .../tvm/auto_scheduler/workload_registry.py | 2 +- src/auto_scheduler/auto_schedule.cc | 3 +- src/auto_scheduler/compute_dag.cc | 5 +- src/auto_scheduler/loop_state.cc | 5 +- src/auto_scheduler/measure.cc | 3 +- src/auto_scheduler/measure_record.cc | 7 +-- .../search_policy/empty_policy.cc | 3 +- .../search_policy/empty_policy.h | 4 +- .../search_policy/search_policy.cc | 3 +- src/auto_scheduler/search_task.cc | 3 +- src/auto_scheduler/transform_step.cc | 16 +++--- tests/cpp/auto_scheduler_test.cc | 6 +- 21 files changed, 144 insertions(+), 157 deletions(-) rename {src => include/tvm}/auto_scheduler/auto_schedule.h (81%) rename {src => include/tvm}/auto_scheduler/compute_dag.h (83%) rename {src => include/tvm}/auto_scheduler/loop_state.h (96%) rename {src => include/tvm}/auto_scheduler/measure.h (93%) rename {src => include/tvm}/auto_scheduler/measure_record.h (83%) rename {src/auto_scheduler/search_policy => include/tvm/auto_scheduler}/search_policy.h (79%) rename {src => include/tvm}/auto_scheduler/search_task.h (97%) rename {src => include/tvm}/auto_scheduler/transform_step.h (98%) diff --git a/src/auto_scheduler/auto_schedule.h b/include/tvm/auto_scheduler/auto_schedule.h similarity index 81% rename from src/auto_scheduler/auto_schedule.h rename to 
include/tvm/auto_scheduler/auto_schedule.h index 55c6992dfd4e7..8477966c02477 100644 --- a/src/auto_scheduler/auto_schedule.h +++ b/include/tvm/auto_scheduler/auto_schedule.h @@ -18,19 +18,17 @@ */ /*! - * \file auto_scheduler/auto_schedule.h - * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get - * schedule search requirements from upper level (Python API), and returns a high performance - * schedule after search process. + * \file tvm/auto_scheduler/auto_schedule.h + * \brief The user interface of the auto scheduler. */ #ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ #define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ -#include +#include +#include -#include "measure.h" -#include "search_policy/search_policy.h" +#include namespace tvm { namespace auto_scheduler { @@ -38,9 +36,9 @@ namespace auto_scheduler { /*! \brief Tuning and measurement options. */ class TuningOptionsNode : public Object { public: - /*! \brief Number of total measurement trials. */ + /*! \brief The number of total measurement trials. */ int num_measure_trials; - /*! \brief Stops early the tuning if no improvement after n measurements. */ + /*! \brief Stops the tuning early if no improvement after n measurements. */ int early_stopping; /*! \brief The number of programs to be measured at each search round. */ int num_measures_per_round; @@ -51,7 +49,7 @@ class TuningOptionsNode : public Object { int verbose; /*! \brief ProgramBuilder which builds the program */ ProgramBuilder builder; - /*! \brief ProgramRunner which runs the program and measure time costs */ + /*! \brief ProgramRunner which runs the program and measures time costs */ ProgramRunner runner; /*! \brief MeasureCallback functions to be called after each measure batch */ Optional> measure_callbacks; @@ -81,8 +79,8 @@ class TuningOptions : public ObjectRef { public: /*! * \brief The constructor - * \param num_measure_trials Number of total measurement trials. 
- * \param early_stopping Stops early the tuning if no improvement after n measurements. + * \param num_measure_trials The number of total measurement trials. + * \param early_stopping Stops the tuning early if no improvement after n measurements. * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule * search. @@ -100,11 +98,11 @@ class TuningOptions : public ObjectRef { }; /*! - * \brief Auto schedule search for a given compute declaration. + * \brief Run schedule search for a given compute declaration. * \param task The search task of the compute declaration. - * \param search_policy The search policy to be used for schedule search. + * \param search_policy The search policy to be used. * \param tuning_options Tuning and measurement options. - * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or + * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or * `tvm.build`. */ TVM_DLL std::pair> AutoSchedule(SearchTask task, diff --git a/src/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h similarity index 83% rename from src/auto_scheduler/compute_dag.h rename to include/tvm/auto_scheduler/compute_dag.h index 6e272d40e9301..12d7f362c64db 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -18,8 +18,8 @@ */ /*! - * \file auto_scheduler/compute_dag.h - * \brief The TVM Auto-scheduler computational graph and related program analyses. + * \file tvm/auto_scheduler/compute_dag.h + * \brief The auto-scheduler's computational graph and related program analyses.
* * We convert a compute declaration described by `tvm.compute` (could be a single operator or a * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, @@ -35,6 +35,8 @@ #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ +#include +#include #include #include @@ -42,8 +44,6 @@ #include #include -#include "loop_state.h" - namespace tvm { namespace auto_scheduler { @@ -89,25 +89,25 @@ class AccessAnalyzer : public ObjectRef { * \brief Return whether this operation needs multi-level tiling * \param op The operation */ - bool NeedsMultiLevelTiling(const te::Operation& op) const; + TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; /*! * \brief Return whether this operation is an injective operation * \param op The operation */ - bool IsInjective(const te::Operation& op) const; + TVM_DLL bool IsInjective(const te::Operation& op) const; /*! * \brief Return whether this operation is strictly inlinable * \param op The operation */ - bool IsStrictInlineable(const te::Operation& op) const; + TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; /*! * \brief Return whether this operation is an output op * \param op The operation */ - bool IsOutput(const te::Operation& op) const; + TVM_DLL bool IsOutput(const te::Operation& op) const; /*! * \brief Get all consumers of on operation @@ -116,8 +116,10 @@ class AccessAnalyzer : public ObjectRef { * \param consumers The return consumer set * \note This function propagates the relation for inlined ops */ - void GetConsumers(const State& state, const te::Operation& op, - std::unordered_set* consumers) const; + TVM_DLL void GetConsumers( + const State& state, + const te::Operation& op, + std::unordered_set* consumers) const; /*! 
* \brief Get all producers of on operation @@ -126,8 +128,10 @@ class AccessAnalyzer : public ObjectRef { * \param producers The return producer set * \note This function propagates the relation for inlined ops */ - void GetProducers(const State& state, const te::Operation& op, - std::unordered_set* producers) const; + TVM_DLL void GetProducers( + const State& state, + const te::Operation& op, + std::unordered_set* producers) const; /*! * \brief Get all direct producers of on operation @@ -167,11 +171,11 @@ class ComputeDAGNode : public Object { Array tensors; /*! \brief All related operations in topo order. */ Array ops; - /*! \brief Number of total float operations for this ComputeDAG. */ + /*! \brief The number of total float operations for this ComputeDAG. */ double flop_ct; /*! \brief The initial state without any transform steps. */ State init_state; - /*! \brief Static read-write access analyzer */ + /*! \brief The static read-write access analyzer */ AccessAnalyzer access_analyzer; void VisitAttrs(tvm::AttrVisitor* v) { @@ -194,16 +198,17 @@ class ComputeDAG : public ObjectRef { /*! \brief The constructor. * \param tensors `te::Tensor`s for a compute declaration. */ - explicit ComputeDAG(Array tensors); + TVM_DLL explicit ComputeDAG(Array tensors); /*! - * \brief Apply the history transform steps from a State to get a TVM schedule. + * \brief Apply the history transform steps to get a TVM schedule. * \param transform_steps Transform steps of a state. - * \param stages A pointer to a `te::Stage` Array, default to be nullptr. - * Pass a valid pointer if these information needs to be used outside this function. - * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr. - * Pass a valid pointer if these information needs to be used outside this function. - * \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. + * \param stages The list of stages after applying the steps. 
+ * Pass a valid pointer if this information needs to be used outside this function. + * \param stage_to_axes The map that stores all axes for one stage. + * Pass a valid pointer if this information needs to be used outside this function. + * \return A `te.schedule` and an Array of `te.Tensor` to be used in `tvm.lower` + * or `tvm.build`. */ std::pair> ApplySteps( const Array& transform_steps, Array* stages = nullptr, @@ -222,9 +227,9 @@ class ComputeDAG : public ObjectRef { * The states can lose complete bound information after some transform steps (e.g., compute_at). * We can call this function to infer and fill all the bound information. * This function calls TVM InferBound pass internally to get the bound. - * The returned state of this function is guaranteed to have complete iterator extent information. - * \param state The state to. - * \return The State after inferbound. + * The returned state of this function is guaranteed to have complete bound information. + * \param state The input state. + * \return The State with complete bound information */ State InferBound(const State& state) const; diff --git a/src/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h similarity index 96% rename from src/auto_scheduler/loop_state.h rename to include/tvm/auto_scheduler/loop_state.h index 4d6477b92b0fe..ab7a52081b938 100644 --- a/src/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -48,6 +48,8 @@ #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_ #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_ +#include +#include #include #include @@ -55,8 +57,6 @@ #include #include -#include "transform_step.h" - namespace tvm { namespace auto_scheduler { @@ -159,10 +159,16 @@ using IterKey = std::pair; */ class AttachMapNode : public Object { public: + struct key_hash : public std::function { + std::size_t operator()(const IterKey& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + }; + /*!
\brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ - std::unordered_map> iter_to_attached_stages; + std::unordered_map, key_hash> iter_to_attached_stages; static constexpr const char* _type_key = "auto_scheduler.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); @@ -381,21 +387,11 @@ class State : public ObjectRef { // Hash and equal function for State namespace std { -/*! \brief The hash function for auto_scheduler::State. */ -template <> -struct hash<::tvm::auto_scheduler::State> { - std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { - return tvm::runtime::ObjectHash()(state.ToStr()); - } -}; - /*! * \brief The equal_to function for auto_scheduler::State. - * We use the schedule result(its string format) of a state to check if two states are `euqal`. - * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two - * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts - * [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result - * to split from outter to inner by factors [8, 16]) + * This function checks the equality by looking at the lowered string format of states. + * If two states with different transform history have the same lowered string format, + * they will be considered equal. */ template <> struct equal_to<::tvm::auto_scheduler::State> { @@ -405,6 +401,14 @@ struct equal_to<::tvm::auto_scheduler::State> { } }; +/*! \brief The hash function for auto_scheduler::State.
*/ +template <> +struct hash<::tvm::auto_scheduler::State> { + std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { + return tvm::runtime::ObjectHash()(state.ToStr()); + } +}; + } // namespace std #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_ diff --git a/src/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h similarity index 93% rename from src/auto_scheduler/measure.h rename to include/tvm/auto_scheduler/measure.h index 02d6e879a1cda..83d7c8d0d3e94 100644 --- a/src/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -23,26 +23,28 @@ * These functions are responsible for building the tvm module, uploading it to remote devices, * recording the running time costs, and checking the correctness of the output. * - * We separate the measurement into two steps: build and run. + * The measurement is separated into two steps: build and run. * A builder builds the executable binary files and a runner runs the binary files to get the * measurement results. The flow of data structures is * * `ProgramBuilder` `ProgramRunner` * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` * - * We implement these in python to utilize python's multiprocessing and error handling. + * The core functions are implemented in python to utilize python's multiprocessing + * and error handling (see also `python/tvm/auto_scheduler/measure.py`). + * This c++ file is just a wrapper for the python functions. */ #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_ #define TVM_AUTO_SCHEDULER_MEASURE_H_ +#include +#include + #include #include #include -#include "loop_state.h" -#include "search_task.h" - namespace tvm { namespace auto_scheduler { @@ -209,7 +211,7 @@ class MeasureCallbackNode : public Object { public: /*! * \brief Callback function that will be called on measurement input/result pairs - * after measurement. + * after each measurement batch. * \param policy The current search policy. * \param inputs An Array of MeasureInput.
* \param results An Array of MeasureResult. @@ -234,7 +236,7 @@ class MeasureCallback : public ObjectRef { /*! \brief ProgramBuilder that builds the programs */ class ProgramBuilderNode : public Object { public: - /*! \brief The number of tasks to run in parallel */ + /*! \brief The number of build processes to run in parallel */ int n_parallel; /*! \brief Timeout of a build */ int timeout; @@ -323,15 +325,15 @@ class LocalBuilder : public ProgramBuilder { * \brief The constructor. * \param timeout The timeout limit (in second) for each build thread. * This will be used in a wrapper of the multiprocessing.Process.join(). - * \param n_parallel Number of threads used to build in parallel. - * \param build_func The name of registered build function. + * \param n_parallel The number of threads used to build in parallel. + * \param build_func The name of the registered build function. */ LocalBuilder(int timeout, int n_parallel, const String& build_func); TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode); }; -/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ +/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */ class LocalRunnerNode : public ProgramRunnerNode { public: Array Run(const Array& inputs, @@ -373,13 +375,12 @@ class RPCRunnerNode : public ProgramRunnerNode { String key; /*! \brief The host address of the RPC Tracker. */ String host; - /*! \brief The port of RPC Tracker. */ + /*! \brief The port of the RPC Tracker. */ int port; /*! \brief The priority of this run request, larger is more prior. */ int priority; /*! \brief The number of tasks run in parallel. */ int n_parallel; - /*! \brief The number of times to run the generated code for taking average. */ Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -395,10 +396,11 @@ class RPCRunnerNode : public ProgramRunnerNode { class RPCRunner : public ProgramRunner { public: /*! 
- * \brief The constructor. + * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py + * for more detailed parameter explanation. * \param key The key of the device registered in the RPC tracker. * \param host The host address of the RPC Tracker. - * \param prot The port of RPC Tracker. + * \param port The port of RPC Tracker. * \param priority The priority of this run request, larger is more prior. * \param n_parallel The number of tasks run in parallel. * \param timeout Timeout of a run. @@ -415,7 +417,7 @@ class RPCRunner : public ProgramRunner { /*! * \brief Measurer that measures the time costs of tvm programs - * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */ + * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */ class ProgramMeasurerNode : public Object { public: /*! \brief Measured programs counter. */ @@ -483,7 +485,7 @@ class ProgramMeasurer : public ObjectRef { * \param callbacks MeasureCallback to be called after each measure batch. * \param verbose Verbosity level. 0 for silent, 1 to output information during program * measuring. - * \param max_continous_error The number of max continuous error. + * \param max_continous_error The maximum number of allowed continuous errors. */ ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional> callbacks, int verbose, diff --git a/src/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h similarity index 83% rename from src/auto_scheduler/measure_record.h rename to include/tvm/auto_scheduler/measure_record.h index 1cfeab07a4009..fa8fe2b1b4556 100644 --- a/src/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -18,26 +18,26 @@ */ /*! - * \file auto_scheduler/measure_record.h - * \brief Json serialization format for dumping and loading tuning records.
+ * \file tvm/auto_scheduler/measure_record.h + * \brief Json serialization format for dumping and loading measurement records. */ #ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ #define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ +#include + #include #include #include -#include "measure.h" - namespace tvm { namespace auto_scheduler { /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { public: - /*! \brief File name for this callback to write log to. */ + /*! \brief The name of output file. */ String filename; void Callback(const SearchPolicy& policy, const Array& inputs, @@ -55,7 +55,7 @@ class RecordToFile : public MeasureCallback { public: /*! * \brief The constructor. - * \param filename File name for this callback to write log. + * \param filename The name of output file */ explicit RecordToFile(String filename); @@ -65,7 +65,7 @@ class RecordToFile : public MeasureCallback { /*! \brief Log reader to load step logs from a file.*/ class RecordReaderNode : public Object { public: - /*! \brief File name for this reader to load log from. */ + /*! \brief The name of input file. */ String filename; /*! \brief The reading file stream. */ std::ifstream infile; @@ -92,7 +92,7 @@ class RecordReaderNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); private: - /*! \brief A string object to store the next line. */ + /*! \brief A string storing the current line. */ std::string cur_line_; }; @@ -104,7 +104,7 @@ class RecordReader : public ObjectRef { public: /*! * \brief The constructor. - * \param filename File name for this callback to write log. + * \param filename The name of input file */ explicit RecordReader(String filename); @@ -112,7 +112,7 @@ class RecordReader : public ObjectRef { }; /*! - * \brief Write measure records to an output stream. + * \brief Append measure records to an output stream. * \param os A pointer to a output stream. 
* \param inputs The MeasureInputs to be written. * \param results The MeasureResults to be written. @@ -122,10 +122,10 @@ void WriteMeasureRecords(std::ostream* os, const Array& inputs, /*! * \brief Read one measure record from a string. - * \param str The record string to be extract. - * \param inp A pointer to a MeasureInputNode, this is used as output. - * \param res A pointer to a MeasureResultNode, this is used as output. - * \param log_version A pointer to a log version string. + * \param str The record string to be parsed. + * \param inp A pointer to a MeasureInputNode used to store the return value. + * \param res A pointer to a MeasureResultNode used to store the return value. + * \param log_version A pointer to a string used to store the log version. */ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version); diff --git a/src/auto_scheduler/search_policy/search_policy.h b/include/tvm/auto_scheduler/search_policy.h similarity index 79% rename from src/auto_scheduler/search_policy/search_policy.h rename to include/tvm/auto_scheduler/search_policy.h index 70f94ad65b94d..457aca1e8f2ec 100644 --- a/src/auto_scheduler/search_policy/search_policy.h +++ b/include/tvm/auto_scheduler/search_policy.h @@ -18,11 +18,11 @@ */ /*! - * \file auto_scheduler/search_policy/search_policy.h + * \file tvm/auto_scheduler/search_policy.h * \brief The base class of search policies, including the abstract definition of search policy and * other supporting data structures. * - * The basic schedule search process for TVM Auto-scheduler is design to be: + * The basic schedule search process for the auto-scheduler is designed to be: * `Program sampling` -> `Performance Tuning`. * * In `Program sampling`, we use some predefined precise or heuristic rules to generate several @@ -31,7 +31,7 @@ * * Candidate schedules are measured against the specific hardware target. * - * \note Adding a new search policy.
+ * \note How to add a new search policy. * In design, there's no need for users to implement their own search policy, our formal search * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule * mechanism will be provided to enable user-defined template search to serve the same functionality @@ -48,16 +48,15 @@ * during the search process. */ -#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ -#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ +#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ +#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ +#include #include #include #include -#include "../search_task.h" - namespace tvm { namespace auto_scheduler { @@ -110,16 +109,16 @@ class SearchPolicyNode : public Object { /*! * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state - * get during the search process. - * \param task The SearchTask or workload key for the computation declaration - * \param num_measure_trials Total schedules to be tried during this search. - * \param early_stopping Early stop if no better schedule is found. - * \param num_measures_per_round Max measure batch in one search round. + * found during the search. + * \param task The SearchTask for the computation declaration + * \param num_measure_trials The number of total measurement trials. + * \param early_stopping Stops the tuning early if no improvement after n measurements. + * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbose level. 0 for silent, 1 to output information during schedule * search. - * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. + * \param measurer A ProgramMeasurer to build and measure programs * \param pre_search_callbacks SearchCallback to be called before schedule search. - * \return The best state get. + * \return The best state found. 
*/ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramMeasurer measurer, @@ -137,16 +136,12 @@ class SearchPolicyNode : public Object { protected: /*! * \brief The set of already measured states. - * During the schedule search process, we may generate `equal states` through different search - * branches. (Equal States: 1. the transform steps are totally the same; 2. even with different - * steps, two states may still result in a same schedule. e.g. To split a axis with extent 512 - * to 3 parts [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can - * get a same result to split from outter to inner by factors [8, 16]) * We store the string format of a state for redundancy check. This is used to make sure a * measured state will never be measured again. */ std::unordered_set measured_states_set_; - /*! \brief The array of already measured states. This can be used in evolutionary search. */ + /*! \brief The array of already measured states. + * The good states can be used as the initial population in evolutionary search. */ std::vector measured_states_vector_; /*! 
\brief The throughputs of already measured states */ std::vector measured_states_throughputs_; @@ -164,4 +159,4 @@ class SearchPolicy : public ObjectRef { } // namespace auto_scheduler } // namespace tvm -#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ +#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ diff --git a/src/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h similarity index 97% rename from src/auto_scheduler/search_task.h rename to include/tvm/auto_scheduler/search_task.h index ca313500cc8fd..85154b5e406b0 100644 --- a/src/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -25,16 +25,15 @@ #ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ +#include #include -#include "compute_dag.h" - namespace tvm { namespace auto_scheduler { class HardwareParams; -/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */ +/*! \brief The parameters of target hardware used to guide the SearchPolicy. */ class HardwareParamsNode : public Object { public: /*! \brief The number of cores. */ diff --git a/src/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h similarity index 98% rename from src/auto_scheduler/transform_step.h rename to include/tvm/auto_scheduler/transform_step.h index ce3ca50ffae68..b23137a9ba5d8 100644 --- a/src/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -19,10 +19,10 @@ /*! * \file auto_scheduler/transform_step.h - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. + * \brief Transformation steps. These steps are used to manipulate the LoopState. + * They are similar to the schedule primitives in te::Stage. * - * \note To add a new transform step: + * \note How to add a new transform step: * Take fuse step for example: * 1. 
Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first * construction function `FuseStep::FuseStep()` in `transform_steps.cc`. @@ -51,8 +51,6 @@ #include #include -#include "utils.h" - namespace tvm { namespace auto_scheduler { @@ -187,7 +185,6 @@ Step StepReadFromRecord(dmlc::JSONReader* reader); * \param step The step to be applied to State. * \param state A mutable pointer to State. * \param dag The original ComputeDAG of this state. - * \return The iterator result after annotate. */ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); @@ -209,7 +206,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes); -/********** Primitives working on single stage **********/ +/********** Steps working on single stage **********/ /*! * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. @@ -478,7 +475,7 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode : public StepNode { diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py index d45dbf8d0aaa4..52aa62baf56f1 100644 --- a/python/tvm/auto_scheduler/auto_schedule.py +++ b/python/tvm/auto_scheduler/auto_schedule.py @@ -57,7 +57,7 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes): @tvm._ffi.register_object("auto_scheduler.SearchTask") class SearchTask(Object): - """ The computation information and hardware parameters for a specific schedule search task. + """ The computation information and hardware parameters for a schedule search task. 
Parameters ---------- @@ -158,9 +158,6 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r def auto_schedule(task, search_policy='default', tuning_options=None): """ Do auto scheduling for a computation declaration. - The task parameter can be a `string` as workload_key, or directly - passing a `SearchTask` as input. - Parameters ---------- task : SearchTask diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 36c2037810732..045720a037eac 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -95,7 +95,7 @@ def make_workload_key(func, args): Returns ------- - workload_key : Str + workload_key : str The workload key of the function. """ global WORKLOAD_FUNC_REGISTRY diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc index b515b3accf7ab..c537ca702b9da 100644 --- a/src/auto_scheduler/auto_schedule.cc +++ b/src/auto_scheduler/auto_schedule.cc @@ -24,8 +24,7 @@ * schedule after search process. */ -#include "auto_schedule.h" - +#include #include namespace tvm { diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index ccea18c80c0fb..92239e51101a6 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -22,8 +22,8 @@ * \brief Compute declaration graph and its related analysis tools. */ -#include "compute_dag.h" - +#include +#include #include #include #include @@ -37,7 +37,6 @@ #include #include -#include "loop_state.h" #include "utils.h" namespace tvm { diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index bfe547864ed10..35d899ad561fe 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -23,14 +23,13 @@ * see auto_scheduler/loop_state.h for more explanation. 
*/ -#include "loop_state.h" - +#include +#include #include #include #include -#include "transform_step.h" #include "utils.h" namespace tvm { diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 6198f60da5a6e..e249f7bd7d286 100644 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -22,8 +22,7 @@ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. */ -#include "measure.h" - +#include #include #include diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 39f9ad86c958e..02f244f93de51 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -22,9 +22,10 @@ * \brief Json serialization format for dumping and loading tuning records. */ -#include "measure_record.h" - #include +#include +#include +#include #include #include @@ -33,8 +34,6 @@ #include #include -#include "loop_state.h" -#include "transform_step.h" #include "utils.h" // Json serialization handler for MeasureInput, MeasureResult diff --git a/src/auto_scheduler/search_policy/empty_policy.cc b/src/auto_scheduler/search_policy/empty_policy.cc index 1886203593a91..4c85af486a610 100644 --- a/src/auto_scheduler/search_policy/empty_policy.cc +++ b/src/auto_scheduler/search_policy/empty_policy.cc @@ -24,10 +24,9 @@ #include "empty_policy.h" +#include #include -#include "../measure.h" - namespace tvm { namespace auto_scheduler { diff --git a/src/auto_scheduler/search_policy/empty_policy.h b/src/auto_scheduler/search_policy/empty_policy.h index 4ccc9c1042eaa..ef7d38ddf1166 100644 --- a/src/auto_scheduler/search_policy/empty_policy.h +++ b/src/auto_scheduler/search_policy/empty_policy.h @@ -26,8 +26,8 @@ #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_ #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_ -#include "../loop_state.h" -#include "search_policy.h" +#include +#include namespace tvm { namespace auto_scheduler { diff --git 
a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc index fba5155edaeab..764b0a7fb97af 100644 --- a/src/auto_scheduler/search_policy/search_policy.cc +++ b/src/auto_scheduler/search_policy/search_policy.cc @@ -22,8 +22,7 @@ * \brief The base class of search policies. */ -#include "search_policy.h" - +#include #include namespace tvm { diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 912a310465409..9cc21f2dfedcc 100644 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -22,8 +22,7 @@ * \brief Meta information and hardware parameters for a search task. */ -#include "search_task.h" - +#include #include #include diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 6c672a5215f2d..b1b3b94370066 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -19,12 +19,12 @@ /*! * \file auto_scheduler/transform_step.cc - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. + * \brief Transformation steps. These steps are used to manipulate the LoopState. + * They are similar to the schedule primitives in te::Stage. 
*/ -#include "transform_step.h" - +#include +#include #include #include @@ -32,7 +32,6 @@ #include #include -#include "loop_state.h" #include "utils.h" namespace tvm { @@ -80,6 +79,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { } void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -101,6 +101,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -122,6 +123,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -142,7 +144,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ""; } -/********** Primitives working on single stage **********/ +/********** Steps working on single stage **********/ /********** Annotation **********/ AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { @@ -741,7 +743,7 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /********** Compute At **********/ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int 
target_iter_id) { diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index 67e54da43f8a9..9aab47855bb40 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -20,16 +20,12 @@ #include #include #include +#include #include #include #include -// todo(merrymercy): expose auto_scheduler header files to `include/tvm` -// and do not use relative path here -#include "../../src/auto_scheduler/compute_dag.h" -#include "../../src/auto_scheduler/loop_state.h" - // Compute declaration for test tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO, int kernel_size, int strides, int padding,