Skip to content

Commit

Permalink
move header files and polish comments
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jul 22, 2020
1 parent 3a8b4b4 commit 56b0187
Show file tree
Hide file tree
Showing 21 changed files with 144 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,27 @@
*/

/*!
* \file auto_scheduler/auto_schedule.h
* \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get
* schedule search requirements from upper level (Python API), and returns a high performance
* schedule after search process.
* \file tvm/auto_scheduler/auto_schedule.h
* \brief The user interface of the auto scheduler.
*/

#ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
#define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_

#include <utility>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/search_policy.h>

#include "measure.h"
#include "search_policy/search_policy.h"
#include <utility>

namespace tvm {
namespace auto_scheduler {

/*! \brief Tuning and measurement options. */
class TuningOptionsNode : public Object {
public:
/*! \brief Number of total measurement trials. */
/*! \brief The number of total measurement trials. */
int num_measure_trials;
/*! \brief Stops early the tuning if no improvement after n measurements. */
/*! \brief Stops the tuning early if no improvement after n measurements. */
int early_stopping;
/*! \brief The number of programs to be measured at each search round. */
int num_measures_per_round;
Expand All @@ -51,7 +49,7 @@ class TuningOptionsNode : public Object {
int verbose;
/*! \brief ProgramBuilder which builds the program */
ProgramBuilder builder;
/*! \brief ProgramRunner which runs the program and measure time costs */
/*! \brief ProgramRunner which runs the program and measures time costs */
ProgramRunner runner;
/*! \brief MeasureCallback functions to be called after each measure batch */
Optional<Array<MeasureCallback>> measure_callbacks;
Expand Down Expand Up @@ -81,8 +79,8 @@ class TuningOptions : public ObjectRef {
public:
/*!
* \brief The constructor
* \param num_measure_trials Number of total measurement trials.
* \param early_stopping Stops early the tuning if no improvement after n measurements.
* \param num_measure_trials The number of total measurement trials.
* \param early_stopping Stops the tuning early if no improvement after n measurements.
* \param num_measures_per_round The number of programs to be measured at each search round.
* \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
* search.
Expand All @@ -100,11 +98,11 @@ class TuningOptions : public ObjectRef {
};

/*!
* \brief Auto schedule search for a given compute declaration.
* \brief Run schedule search for a given compute declaration.
* \param task The search task of the compute declaration.
* \param search_policy The search policy to be used for schedule search.
* \param search_policy The search policy to be used.
* \param tuning_options Tuning and measurement options.
* \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or
* \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or
* `tvm.build`.
*/
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*r
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
Expand All @@ -18,8 +18,8 @@
*/

/*!
* \file auto_scheduler/compute_dag.h
* \brief The TVM Auto-scheduler computational graph and related program analyses.
* \file tvm/auto_scheduler/compute_dag.h
* \brief The auto-scheduler's computational graph and related program analyses.
*
* We convert a compute declaration described by `tvm.compute` (could be a single operator or a
* subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
Expand All @@ -35,15 +35,15 @@
#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_

#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/te/schedule.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "loop_state.h"

namespace tvm {
namespace auto_scheduler {

Expand Down Expand Up @@ -89,25 +89,25 @@ class AccessAnalyzer : public ObjectRef {
* \brief Return whether this operation needs multi-level tiling
* \param op The operation
*/
bool NeedsMultiLevelTiling(const te::Operation& op) const;
TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;

/*!
* \brief Return whether this operation is an injective operation
* \param op The operation
*/
bool IsInjective(const te::Operation& op) const;
TVM_DLL bool IsInjective(const te::Operation& op) const;

/*!
* \brief Return whether this operation is strictly inlinable
* \param op The operation
*/
bool IsStrictInlineable(const te::Operation& op) const;
TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;

/*!
* \brief Return whether this operation is an output op
* \param op The operation
*/
bool IsOutput(const te::Operation& op) const;
TVM_DLL bool IsOutput(const te::Operation& op) const;

/*!
* \brief Get all consumers of on operation
Expand All @@ -116,8 +116,10 @@ class AccessAnalyzer : public ObjectRef {
* \param consumers The return consumer set
* \note This function propagates the relation for inlined ops
*/
void GetConsumers(const State& state, const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) const;
TVM_DLL void GetConsumers(
const State& state,
const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* consumers) const;

/*!
* \brief Get all producers of on operation
Expand All @@ -126,8 +128,10 @@ class AccessAnalyzer : public ObjectRef {
* \param producers The return producer set
* \note This function propagates the relation for inlined ops
*/
void GetProducers(const State& state, const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;
TVM_DLL void GetProducers(
const State& state,
const te::Operation& op,
std::unordered_set<te::Operation, ObjectHash, ObjectEqual>* producers) const;

/*!
* \brief Get all direct producers of on operation
Expand Down Expand Up @@ -167,11 +171,11 @@ class ComputeDAGNode : public Object {
Array<te::Tensor> tensors;
/*! \brief All related operations in topo order. */
Array<te::Operation> ops;
/*! \brief Number of total float operations for this ComputeDAG. */
/*! \brief The number of total float operations for this ComputeDAG. */
double flop_ct;
/*! \brief The initial state without any transform steps. */
State init_state;
/*! \brief Static read-write access analyzer */
/*! \brief The static read-write access analyzer */
AccessAnalyzer access_analyzer;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -194,16 +198,17 @@ class ComputeDAG : public ObjectRef {
/*! \brief The constructor.
* \param tensors `te::Tensor`s for a compute declaration.
*/
explicit ComputeDAG(Array<te::Tensor> tensors);
TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);

/*!
* \brief Apply the history transform steps from a State to get a TVM schedule.
* \brief Apply the history transform steps to get a TVM schedule.
* \param transform_steps Transform steps of a state.
* \param stages A pointer to a `te::Stage` Array, default to be nullptr.
* Pass a valid pointer if these information needs to be used outside this function.
* \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr.
* Pass a valid pointer if these information needs to be used outside this function.
* \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
* \param stages The list of stages after applying the steps.
* Pass a valid pointer if this information needs to be used outside this function.
* \param stage_to_axes The map that stores all axes for one stage.
* Pass a valid pointer if this information needs to be used outside this function.
* \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
* or `tvm.build`.
*/
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
Expand All @@ -222,9 +227,9 @@ class ComputeDAG : public ObjectRef {
* The states can lose complete bound information after some transform steps (e.g., compute_at).
* We can call this function to infer and fill all the bound information.
* This function calls TVM InferBound pass internally to get the bound.
* The returned state of this function is guaranteed to have complete iterator extent information.
* \param state The state to.
* \return The State after inferbound.
* The returned state of this function is guaranteed to have complete bound information.
* \param state The input state.
* \return The State with complete bound information
*/
State InferBound(const State& state) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
#ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
#define TVM_AUTO_SCHEDULER_LOOP_STATE_H_

#include <dmlc/common.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/container.h>

#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

#include "transform_step.h"

namespace tvm {
namespace auto_scheduler {

Expand Down Expand Up @@ -159,10 +159,16 @@ using IterKey = std::pair<int, int>;
*/
class AttachMapNode : public Object {
public:
struct key_hash : public std::function<std::size_t(IterKey)> {
std::size_t operator()(const IterKey& k) const {
return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
}
};

/*! \brief A Map to store the mapping of stage to its attached iterator. */
std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
/*! \brief A Map to store the mapping of iterator to the stage attached to it. */
std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages;
std::unordered_map<IterKey, std::vector<StageKey>, key_hash> iter_to_attached_stages;

static constexpr const char* _type_key = "auto_scheduler.AttachMap";
TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
Expand Down Expand Up @@ -381,21 +387,11 @@ class State : public ObjectRef {
// Hash and equal function for State
namespace std {

/*! \brief The hash function for auto_scheduler::State. */
template <>
struct hash<::tvm::auto_scheduler::State> {
std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
return tvm::runtime::ObjectHash()(state.ToStr());
}
};

/*!
* \brief The equal_to function for auto_scheduler::State.
* We use the schedule result(its string format) of a state to check if two states are `euqal`.
* Equal States: 1. the transform steps are totally the same; 2. even with different steps, two
* states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts
* [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result
* to split from outter to inner by factors [8, 16])
* This function checkes the equality by looking at the lowered string format of states.
* If two states with different transform history have the same lowered string format,
* they will be considered being equal.
*/
template <>
struct equal_to<::tvm::auto_scheduler::State> {
Expand All @@ -405,6 +401,14 @@ struct equal_to<::tvm::auto_scheduler::State> {
}
};

/*! \brief The hash function for auto_scheduler::State. */
template <>
struct hash<::tvm::auto_scheduler::State> {
std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
return tvm::runtime::ObjectHash()(state.ToStr());
}
};

} // namespace std

#endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
Loading

0 comments on commit 56b0187

Please sign in to comment.