[DOC] minor language use improvements #3317

Merged (3 commits) on Jun 7, 2019
docs/dev/codebase_walkthrough.rst (14 changes: 7 additions & 7 deletions)
@@ -43,7 +43,7 @@ When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compile
- Generate a compute expression and a schedule for the operator
- Compile the operator into object code
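
For a toy element-wise operator, the last two steps above boil down to roughly the following (an illustrative sketch using the low-level tensor expression API, not the actual TOPI convolution code):

::

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")   # compute expression
    s = tvm.create_schedule(B.op)                          # schedule
    mod = tvm.build(s, [A, B], target="llvm")              # compiled object code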

- One of the interesting aspects of TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay.
+ One of the interesting aspects of TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that does heavy lifting is implemented in C++, and Python bindings are provided for the user interface. This is also true in TVM, but in TVM codebase, C++ code can also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay.

*******************************************
Vector Add Example
@@ -84,7 +84,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ
args[4]);
});

- We use ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of `PackedFunc <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_. ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
+ We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
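
The same registry is also reachable from Python, which is what the C++ to Python direction builds on. A minimal sketch (the name ``my_py_func`` is made up for illustration):

::

    import tvm

    # Register a Python function as a global PackedFunc; C++ code can look it
    # up by the same name through the global function registry.
    @tvm.register_func("my_py_func")
    def my_py_func(x):
        return x + 1

    f = tvm.get_global_func("my_py_func")   # the Python-side equivalent of the C++ lookup
    assert f(10) == 11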

A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
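
The dependency chain can be walked directly from Python, for example (a small sketch):

::

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] * 2, name="B")

    print(B.op)                # the ComputeOp that produces B
    print(B.op.input_tensors)  # [A]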

@@ -141,7 +141,7 @@ Bound inference is the process where all loop bounds and sizes of intermediate b
.. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html


- ``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects that changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.
+ ``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.

Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
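
A convenient way to watch this happen is ``tvm.lower`` with ``simple_mode=True``, which returns the loop nest after the lowering passes have run. A small sketch: the ``split`` shows up as nested loops, and the vectorized inner loop becomes a ramp expression.

::

    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")

    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)   # reflected in the initial loop nest
    s[B].vectorize(xi)                            # handled by the vectorization pass

    print(tvm.lower(s, [A, B], simple_mode=True))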

@@ -173,7 +173,7 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/
}


- ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
+ The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:

::

@@ -182,9 +182,9 @@
*rv = BuildCUDA(args[0]);
});

- ``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code.
+ The ``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code.

- ``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend.
+ The ``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages the CUDA driver API. The ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend.
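
For instance, when the vector add example is built for CUDA, the device module and the source produced by ``CodeGenCUDA`` can be inspected from Python. A sketch, assuming a CUDA-enabled TVM build with a GPU available:

::

    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.placeholder((n,), name="B")
    C = tvm.compute((n,), lambda i: A[i] + B[i], name="C")

    s = tvm.create_schedule(C.op)
    bx, tx = s[C].split(C.op.axis[0], factor=64)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    fadd = tvm.build(s, [A, B, C], target="cuda", name="vecadd")
    dev_module = fadd.imported_modules[0]   # device module backed by CUDAModuleNode
    print(dev_module.get_source())          # CUDA source emitted by CodeGenCUDA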

The returned module, which can be thought of as a combination of a compiled function and a device API, can be invoked on TVM's NDArray objects.
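
A sketch of that invocation on CPU (the same pattern applies to a GPU context with ``tvm.gpu`` instead of ``tvm.cpu``):

::

    import numpy as np
    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.placeholder((n,), name="B")
    C = tvm.compute((n,), lambda i: A[i] + B[i], name="C")
    fadd = tvm.build(tvm.create_schedule(C.op), [A, B, C], target="llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros(n, dtype="float32"), ctx)

    fadd(a, b, c)   # dispatched through PackedFunc into the compiled code
    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())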

@@ -243,4 +243,4 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal
}
};

- This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, at the end all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase.
+ This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, in the end, all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase.
tutorials/language/schedule_primitives.py (2 changes: 1 addition & 1 deletion)
@@ -144,7 +144,7 @@
######################################################################
# compute_at
# ----------
- # For a schedule consists of multiple operators, TVM will compute
+ # For a schedule that consists of multiple operators, TVM will compute
# tensors at the root separately by default.
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B')
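# Illustrative sketch (assumes a downstream stage C consuming B, e.g.
# C = tvm.compute((m,), lambda i: B[i]*2, name='C')): compute_at moves the
# computation of B into C's loop instead of leaving it in a separate root loop:
#
#     s = tvm.create_schedule(C.op)
#     s[B].compute_at(s[C], C.op.axis[0])
#     print(tvm.lower(s, [A, B, C], simple_mode=True))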
tutorials/tensor_expr_get_started.py (2 changes: 1 addition & 1 deletion)
@@ -108,7 +108,7 @@

######################################################################
# Finally we bind the iteration axis bx and tx to threads in the GPU
- # compute grid. These are GPU specific constructs that allows us
+ # compute grid. These are GPU specific constructs that allow us
# to generate code that runs on GPU.
#
if tgt == "cuda" or tgt.startswith('opencl'):