From 199c3adf2584430d0196280bad71e2bd0b1d4efa Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Wed, 20 Jun 2018 14:16:21 -0700 Subject: [PATCH] [DOC] VTA installation & basic tutorials (#47) --- vta/Makefile | 1 + vta/README.md | 17 +- vta/apps/pynq_rpc/README.md | 73 ----- vta/apps/pynq_rpc/start_rpc_server.sh | 2 +- vta/docs/how_to/install.md | 369 ++++++++++++++++++++- vta/examples/resnet18/pynq/README.md | 123 +------ vta/python/vta/environment.py | 2 +- vta/tutorials/get_started.py | 380 ++++++++++++++++++++- vta/tutorials/matrix_multiply.py | 453 ++++++++++++++++++++++++++ vta/tutorials/matrix_multiply_opt.py | 362 ++++++++++++++++++++ 10 files changed, 1554 insertions(+), 228 deletions(-) delete mode 100644 vta/apps/pynq_rpc/README.md create mode 100644 vta/tutorials/matrix_multiply.py create mode 100644 vta/tutorials/matrix_multiply_opt.py diff --git a/vta/Makefile b/vta/Makefile index c43c49bddaca6..60e93c0bd3477 100644 --- a/vta/Makefile +++ b/vta/Makefile @@ -63,6 +63,7 @@ doc: clean: $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o + $(RM) -rf cat.jpg quantize_graph.json quantize_params.pkl synset.txt -include build/*.d diff --git a/vta/README.md b/vta/README.md index cd3a490ccb6f3..bb347f02b7984 100644 --- a/vta/README.md +++ b/vta/README.md @@ -7,20 +7,27 @@ VTA(versatile tensor accelerator) is an open-source deep learning accelerator st It is not just an open-source hardware, but is an end to end solution that includes the entire software stack on top of VTA open-source hardware. - The key features include: - Generic, modular open-source hardware - Streamlined workflow to deploy to FPGAs. - - Simulator support -- Driver and JIT runtime for both simulated backend and FPGA. + - Simulator support to protoype compilation passes on regular workstations. +- Driver and JIT runtime for both simulated and FPGA hardware backend. - End to end TVM stack integration - Direct optimization and deploy models from deep learning frameworks via TVM stack. - - Customized and extendible TVM compiler backend - - Flexible RPC support to ease the deployment, you can program it with python :) + - Customized and extendible TVM compiler backend. + - Flexible RPC support to ease the deployment, and program FPGAs with Python VTA is part of our effort on [TVM Stack](http://www.tvmlang.org/). +VTA Installation +---------------- +To get started with VTA, please follow the [Installation Guide](docs/how_to/install.md) + +ResNet-18 Inference Example +--------------------------- +To offload ResNet-18 inference, follow the [ResNet-18 Guide](examples/resnet18/pynq/README.md) + License ------- © Contributors, 2018. Licensed under an [Apache-2.0](https://github.com/tmoreau89/vta/blob/master/LICENSE) license. diff --git a/vta/apps/pynq_rpc/README.md b/vta/apps/pynq_rpc/README.md deleted file mode 100644 index bb5f9b0b6c0be..0000000000000 --- a/vta/apps/pynq_rpc/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# PYNQ RPC Server for VTA - -This guide describes how to setup a Pynq-based RPC server to accelerate deep learning workloads with VTA. - -## Pynq Setup - -Follow the getting started tutorial for the [Pynq board](http://pynq.readthedocs.io/en/latest/getting_started.html). -* This assumes that you've downloaded the latest Pynq image, PYNQ-Z1 v2.1 (released 21 Feb 2018). -* For this RPC setup, follow the ["Connect to a Computer"](http://pynq.readthedocs.io/en/latest/getting_started.html#connect-to-a-computer) Pynq setup instructions. -* To be able to talk to the board, you'll need to make sure that you've followed the steps to [assign a static IP address](http://pynq.readthedocs.io/en/latest/appendix.html#assign-your-computer-a-static-ip) - -Make sure that you can talk to your Pynq board successfully: -```bash -ping 192.168.2.99 -``` - -When ssh-ing onto the board, the password for the `xilinx` username is `xilinx`. - -For convenience let's go ahead and mount the Pynq board's file system to easily access it (this will require sshfs to be installed): -```bash -mkdir -sshfs xilinx@192.168.2.99:/home/xilinx -``` - -## Pynq TVM & VTA installation - -On your **host PC**, go to the `` directory of your Pynq board file system. -```bash -cd -``` - -From there, clone the VTA repository: -```bash -git clone git@github.com:uwsaml/vta.git --recursive -``` - -Now, ssh into your **Pynq board** to build the TVM runtime with the following commands. This build should take about 5 minutes. -```bash -ssh xilinx@192.168.2.99 # ssh if you haven't done so -cd ~/vta/nnvm/tvm -cp make/config.mk . -echo USE_RPC=1 >> config.mk -make runtime -j2 -``` - -We're now ready to build the Pynq RPC server on the Pynq board, which should take less than 30 seconds. -```bash -ssh xilinx@192.168.2.99 # ssh if you haven't done so -cd ~/vta -make -j2 -``` - -Add VTA and TVM to PYTHONPATH: -``` -export PYTHONPATH=$PYTHONPATH:/home/xilinx/vta/python:/home/xilinx/vta/nnvm/tvm/python -``` - -The last stage will build the `vta/lib/libvta.so` library file. We are now ready to launch the RPC server on the Pynq. In order to enable the FPGA drivers, we need to run the RPC server with `sudo` privileges. -```bash -ssh xilinx@192.168.2.99 # ssh if you haven't done so -cd ~/vta -sudo PYTHONPATH=$PYTHONPATH ./apps/pynq_rpc/start_rpc_server.sh # pw is xilinx -``` - -You should see the following being displayed when starting the RPC server: -``` -INFO:root:Load additional library /home/xilinx/vta/lib/libvta.so -INFO:root:RPCServer: bind to 0.0.0.0:9091 -``` - -Note that it should be listening on port `9091`. - -To kill the RPC server, just enter the `Ctrl + c` command. diff --git a/vta/apps/pynq_rpc/start_rpc_server.sh b/vta/apps/pynq_rpc/start_rpc_server.sh index fac12e82bb55b..e36d80ccc2603 100755 --- a/vta/apps/pynq_rpc/start_rpc_server.sh +++ b/vta/apps/pynq_rpc/start_rpc_server.sh @@ -1,4 +1,4 @@ #!/bin/bash -export PYTHONPATH=${PYTHONPATH}:/home/xilinx/vta/nnvm/tvm/python:/home/xilinx/vta/python +export PYTHONPATH=${PYTHONPATH}:/home/xilinx/vta/tvm/python:/home/xilinx/vta/python export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ python -m vta.exec.rpc_server diff --git a/vta/docs/how_to/install.md b/vta/docs/how_to/install.md index c6bab2f1a6606..e2ccf63e90801 100644 --- a/vta/docs/how_to/install.md +++ b/vta/docs/how_to/install.md @@ -1,19 +1,368 @@ -Installation Guide -================== -This page gives instructions on how to build and use VTA +Installation Guides +=================== + +We present three installation guides, each extending on the previous one: +1. VTA simulation-only installation +2. VTA hardware testing setup with the [Pynq](http://www.pynq.io/) FPGA development board +3. VTA hardware compilation tool chain installation + +## VTA Simulation-Only Installation + +This first guide details the installation of the VTA package to run hardware simulation tests locally on your development machine (in case you don't own the Pynq FPGA development board). +This guide includes: +1. Software dependences installation +2. Simulation library compilation +3. Python package installation +4. Test examples to ensure that the VTA package was correctly installed + +To get started, clone vta repo from [github](https://github.com/uwsaml/vta). It is important to clone the submodules along with ```--recursive``` option. +```bash +git clone --recursive https://github.com/uwsaml/vta +``` + +### VTA Dependences + +The VTA package depends on several other packages that need to be manually installed beforehand. + +We list the dependences below: +* LLVM 4.0 or newer +* TVM +* MxNet (to run the end-to-end examples) +* Additional python packages + +#### LLVM Installation + +We provide the set of commands to install LLVM 6.0 (stable branch) on Ubuntu Xenial. Note that the [LLVM installation process](apt.llvm.org) can be adapted to different LLVM branches, and operating systems. + +```bash +wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - +sudo apt-add-repository "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main" +sudo apt-get update +apt-get install clang-6.0 lldb-6.0 lld-6.0 +``` + +To ensure that LLVM 6.0 was properly installed, check that the following command gives the path to your `llvm-config` binary (you may have to append the version number to the executable name): + +```bash +which llvm-config-6.0 +``` + +#### TVM Installation + +TVM is included as a top-level submodule to VTA, and can be found under `/tvm`. + +Follow the [installation instructions](https://docs.tvm.ai/install/index.html). + +In the 'config.mk' file, make sure that: +* `LLVM_CONFIG` points to the `llvm-config` executable which path was derived in the LLVM installation instructions above (e.g. `LLVM_CONFIG = /usr/bin/llvm-config-6.0`) +* `USE_RPC` should be set to 1 + +For the *Python Package Installation*, we recommend updating your `~/.bashrc` file to extend your `PYTHONPATH` with the TVM Python libraries. +```bash +export PYTHONPATH=/python:/topi/python:${PYTHONPATH} +``` + +#### NNVM Installation + +Clone the NNVM repository from `tqchen` in the directory of your choosing: +```bash +git clone git@github.com:tqchen/nnvm.git --recursive +``` + +To run this example, we rely on a special branch of NNVM `qt`: +```bash +cd +git checkout qt +``` + +Launch the compilation, this takes about a minute on two threads. +```bash +cd +make -j2 +``` + +Finally update your `~/.bashrc` file to include the NNVM python libraries in your `PYTHONPATH`: +```bash +export PYTHONPATH=/python:${PYTHONPATH} +``` + +#### MxNet Installation + +Follow the [MxNet Installation Instructions](https://mxnet.incubator.apache.org) + +#### Python Dependences + +You'll need the following packages to be installed for the example to run properly. You can use `pip` to install those packages: +* `decorator` +* `enum34` +* `Pillow` +* `wget` + +### VTA Shared Library Compilation + +Before building the VTA shared library, the VTA configuration can be modified by changing `config.json` file. +This file provides an architectural specification of the VTA accelerator that can be understood by both the TVM compiler stack and the VTA hardware stack. +It also specifies the TVM compiler target. When `TARGET` is set to `sim`, it tells the TVM compiler to execute the TVM workloads on the VTA simulator. + +To build the simulator library, copy the simulation configuration file `make/sim_sample.json` to the project root. +Next, you can build the VTA simulation dynamic library with `make`. + +```bash +cd +cp make/sim_sample.json config.json +make -j4 +``` + +### VTA Python Package Installation + +The Python package can installed by extending your `PYTHONPATH` environment variable to point to the VTA python library path. +You can run the following line in a terminal, or add it to your `~/.bashrc` file if you plan on using VTA regularly. + +```bash +export PYTHONPATH=/python:${PYTHONPATH} +``` + +### Testing your VTA Simulation Setup + +Finally to ensure that you've properly installed the VTA package, we can run simple unit tests and the ResNet-18 inference example. + +Let's first run the 2D convolution test bench that will only run the ResNet-18 convolution layers. + +```bash +python tests/python/integration/test_benchmark_topi_conv2d.py +``` + +> Note: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves, by evaluating the convolution in software. + +Next we can also run the ResNet-18 end to end example in the VTA simulator. +This test will download the following files in your root: +* `cat.jpg` the test image to classify +* `synset.txt` the ImageNet categories +* `quantize_graph.json` the 8-bit ResNet-18 inference NNVM graph +* `quantize_params.plk` the 8-bit ResNet-18 model parameters + +```bash +python examples/resnet18/pynq/imagenet_predict.py +``` +> Note: This will run ResNet inference by offloading the compute-heavy convolution layers to the VTA simulator, and report the top-1 category, and the inference time cost in seconds. + +## VTA Pynq-Based Testing Setup + +This second guide extends the *VTA Simulation-Only Installation* guide above to allow FPGA-based hardware tests of the full TVM and VTA software-hardware stack. +In terms of hardware components you'll need: +* The [Pynq](http://www.pynq.io/) FPGA development board which can be acquired for $200, or $150 for academics from [Digilent](https://store.digilentinc.com/pynq-z1-python-productivity-for-zynq/). +* An Ethernet-to-USB adapter to connect the Pynq board to your development computer. +* An 8+GB micro SD card the (can be ordered with the Pynq dev kit). +* An AC to DC 12V 3A power adapter (can be ordered with the Pynq dev kit). + +This guide includes: +1. Pynq board setup instructions +2. Pynq-side RPC server build and deployment +3. Revisiting the test examples from the *VTA Simulation-Only Installation* guide, this time executing on the Pynq board + +### Pynq Board Setup + +Setup your Pynq board based on the *Getting Started* tutorial for the [Pynq board](http://pynq.readthedocs.io/en/latest/getting_started.html). You should follow the instructions up to and including the *Turning On the PYNQ-Z1* steps (no need to pursue *Getting Started* tutorial beyond this point). +* Make sure that you've downloaded the latest Pynq image, PYNQ-Z1 v2.1 (released 21 Feb 2018), and have imaged your SD card with it. +* For this particular setup, follow the ["Connect to a Computer"](http://pynq.readthedocs.io/en/latest/getting_started.html#connect-to-a-computer) Ethernet setup instructions. + * To be able to talk to the board, make sure to [assign your computer a static IP address](http://pynq.readthedocs.io/en/latest/appendix.html#assign-your-computer-a-static-ip) + +Once the board is powered on and connected to your development host machine, try connecting to it to make sure you've properly set up your Pynq board: +```bash +# To connect to the Pynq board use the [username, password] combo: [xilinx, xilinx] +ssh xilinx@192.168.2.99 +``` + +### Pynq-Side RPC Server Build & Deployment + +Because the direct board-to-computer connection prevents the board from directly accessing the internet, we'll need to mount the Pynq's file system to your development machine's file system with [sshfs](https://www.digitalocean.com/community/tutorials/how-to-use-sshfs-to-mount-remote-file-systems-over-ssh). Next we directly clone the VTA repository into the mountpoint from your development machine. -To get started, clone tvm repo from github. It is important to clone the submodules along, with ```--recursive``` option. ```bash +mkdir +sshfs xilinx@192.168.2.99:/home/xilinx +cd git clone --recursive https://github.com/uwsaml/vta +# When finished, you can leave the moutpoint and unmount the directory +cd ~ +sudo umount +``` + +Now that we've cloned the VTA repository in the Pynq's file system, we can ssh into it and launch the build of the TVM-based RPC server. +The build process should take roughly 5 minutes. + +```bash +ssh xilinx@192.168.2.99 +# Build TVM runtime library (takes 5 mins) +cd /home/xilinx/vta/tvm +mkdir build +cp cmake/config.cmake build/. +cd build +cmake .. +make runtime -j2 +# Build VTA RPC server (takes 1 min) +cd /home/xilinx/vta +sudo ./apps/pynq_rpc/start_rpc_server.sh # pw is 'xilinx' +``` + +You should see the following being displayed when starting the RPC server. In order to run the next examples, you'll need to leave the RPC server running in an `ssh` session. +``` +INFO:root:Load additional library /home/xilinx/vta/lib/libvta.so +INFO:root:RPCServer: bind to 0.0.0.0:9091 +``` + +Tips regarding the Pynq RPC Server: +* The RPC server should be listening on port `9091`. If not, an earlier process might have terminated unexpectedly and it's recommended in this case to just reboot the Pynq, and re-run the RPC server. +* To kill the RPC server, just send the `Ctrl + c` command. You can re-run it with `sudo ./apps/pynq_rpc/start_rpc_server.sh`. +* If unresponsive, the board can be rebooted by power-cycling it with the physical power switch. + +### Testing your VTA Pynq-based Hardware Setup + +Before running the examples you'll need to configure your environment as follows: +```bash +export VTA_PYNQ_RPC_HOST=192.168.2.99 +export VTA_PYNQ_RPC_PORT=9091 +``` + +In addition, you'll need to edit the `config.json` file to indicate that we are targeting the Pynq platform, by setting the `TARGET` field to the `"pynq"` value. Alternatively, you can copy the default `make/config.json` into the VTA root. +> Note: in contrast to our simulation setup, there are no libraries to compile on the host side since the host offloads all of the computation to the Pynq board. + +```bash +cd +cp make/config.json . +``` + +This time again, we will run the 2D convolution testbench. But beforehand, we'll need to program the Pynq's own FPGA with a VTA bitstream, and build the VTA runtime on the Pynq via RPC. The following `test_program_rpc.py` script will perform two operations: +* FPGA programming, by downloading a pre-compiled bitstream from a [VTA bitstream repository](https://github.com/uwsaml/vta-distro) that matches the default `config.json` configuration set by the host, and sending it over to the Pynq via RPC to program the Pynq's FPGA. +* Runtime building on the Pynq, which needs to be run everytime the `config.json` configuration is modified. This ensures that the VTA software runtime that generates the accelerator's executable via just-in-time (JIT) compilation matches the specifications of the VTA design that is programmed on the FPGA. The build process takes about 30 seconds to complete. + +```bash +python tests/python/pynq/test_program_rpc.py +``` + +> Tip: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq `ssh` session. + +We are now ready to run the 2D convolution testbench for the ResNet-15 workload in hardware. + +```bash +python tests/python/pynq/test_benchmark_conv2d.py ``` -For windows users who use github tools, you can open the git shell, and type the following command. + +The performance metrics measured on the Pynq board will be reported for each convolutional layer. + +Finally, we run the ResNet-18 end-to-end example on the Pynq. + ```bash -git submodule init -git submodule update --init --recursive +python examples/resnet18/pynq/imagenet_predict.py ``` +This will run ResNet inference by offloading the compute-heavy convolution layers to the Pynq's FPGA-based VTA accelerator. The time cost is also measured in seconds here. + +## VTA Hardware Toolchain Installation + +This third and last guide allows users to generate custom VTA bitstreams using free-to-use Xilinx compilation toolchains. + +This guide includes: +1. Xilinx toolchain installation (for Linux) +2. Custom VTA bitstream compilation +3. Running the end to end ResNet-18 test with the new bitstream + +### Xilinx Toolchain Installation -## Build Hardware +We recommend using `Vivado 2017.1` since our scripts have been tested to work on this version of the Xilinx toolchains. Our guide is written for Linux installation. + +You’ll need to install Xilinx’ FPGA compilation toolchain, [Vivado HL WebPACK 2017.1](https://www.xilinx.com/products/design-tools/vivado.html), which a license-free version of the Vivado HLx toolchain. + +#### Obtaining and Launching the Vivado GUI Installer + +1. Go to the [download webpage](https://www.xilinx.com/support/download.html), and download the Linux Self Extracting Web Installer for Vivado HL 2017.1 WebPACK and Editions. +2. You’ll have to sign in with a Xilinx account. This requires a Xilinx account creation that will take 2 minutes. +3. Complete the Name and Address Verification by clicking “Next”, and you will get the opportunity to download a binary file, called `Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin`. +4. Now that the file is downloaded, go to your `Downloads` directory, and change the file permissions so it can be executed: +```bash +chmod u+x Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin +``` +5. Now you can execute the binary: +```bash +./Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin +``` + +#### Xilinx Vivado GUI Installer Steps + +At this point you've launched the Vivado 2017.1 Installer GUI program. + +1. Click “Next” on the *Welcome* screen. +2. Enter your Xilinx User Credentials under “User Authentication” and select the “Download and Install Now” before clicking “Next” on the *Select Install Type* screen. +3. Accept all terms before clicking on “Next” on the *Accept License Agreements* screen. +4. Select the “Vivado HL WebPACK” before clicking on “Next” on the *Select Edition to Install* screen. +5. Under the *Vivado HL WebPACK* screen, before hitting “Next", check the following options (the rest should be unchecked): + * Design Tools -> Vivado Design Suite -> Vivado + * Design Tools -> Vivado Design Suite -> Vivado High Level Synthesis + * Devices -> Production Services -> SoCs -> Zynq-7000 Series +6. Your total download size should be about 3GB and the amount of Disk Space Required 13GB. +7. Set the installation directory before clicking “Next” on the *Select Destination Directory* screen. It might highlight some paths as red - that’s because the installer doesn’t have the permission to write to that directory. In that case select a path that doesn’t require special write permissions (e.g. in your home directory). +8. Hit “Install” under the *Installation Summary* screen. +9. An *Installation Progress Window* will pop-up to track progress of the download and the installation. +10. This process will take about 20-30 minutes depending on your connection speed. +11. A pop-up window will inform you that the installation completed successfully. Click "OK". +12. Finally the *Vivado License Manager* will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process. + +#### Environment Setup + +The last step is to update your `~/.bashrc` with the following lines. This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line. +```bash +# Xilinx Vivado 2017.1 environmentexport XILINX_VIVADO=/home/moreau/Xilinx/SDx/2017.1/Vivado +export XILINX_VIVADO=/home/moreau/Xilinx/SDx/2017.1/Vivado +export XILINX_HLS=/home/moreau/Xilinx/SDx/2017.1/Vivado_HLS +export XILINX_SDK=/home/moreau/Xilinx/SDx/2017.1/SDK +export PATH=${XILINX_VIVADO}/bin:${PATH} +export PATH=${XILINX_HLS}/bin:${PATH} +export PATH=${XILINX_SDK}/bin:${PATH} +``` + +### Custom VTA Bitstream Compilation + +High-level parameters are listed under `/make/config.json` and can be customized by the user. For this custom VTA Bitstream Compilation exercise, we'll change the frequency of our design, so it can be clocked a little faster. +* Set the `HW_FREQ` field to `142`. The Pynq board supports 100, 142, 167 and 200MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violation and thus faulty hardware. +* Set the `HW_CLK_TARGET` to `6`. This parameters refers to the target clock period in ns passed to HLS - a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142MHz clock would require a 7ns target, but we intentionally lower the clock target to 6ns to more aggressively pipeline our design. + +Bitstream generation is driven by a top-level `Makefile` under `/hardware/xilinx/`. + +If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter: +```bash +cd /hardware/xilinx +make ip MODE=sim +``` + +If you just want to generate the HLS-based VTA IP cores without launching the entire design place and route, enter: +```bash +make ip +``` +You'll be able to view the HLS synthesis reports under `/build/hardware/xilinx/hls///solution0/syn/report/_csynth.rpt` +> Note: The `` name is a string that summarizes the VTA configuration parameters specified in the `config.json`. The `` name refers to the specific module in the VTA pipeline. + +Finally to run the full hardware compilation and generate the bitstream, run: + +```bash +make +``` + +This process is lenghty, and can take around up to an hour to complete depending on your machine's specs. We recommend setting the `VTA_HW_COMP_THREADS` variable in the Makefile to take full advantage of all the cores on your development machine. + +Once the compilation completes, the generated bitstream can be found under `/build/hardware/xilinx/vivado//export/vta.bit`. + +### End-to-end ResNet-18 Example with the Custom Bitstream + +Let's run the ResNet-18 example with our newly generated bitstream. + +In `/examples/resnet18/pynq/imagenet_predict.py`, change the line: +```python +vta.program_fpga(remote, bitstream=None) +``` +to + +```python +vta.program_fpga(remote, bitstream="/build/hardware/xilinx/vivado//export/vta.bit") +``` -## Build Runtime +Instead of downloading the bitstream from the bitstream repository, the programmer will instead use the custom bitstream you just generated, which is a VTA design clocked at a higher frequency. -## Use VTA Python Compiler Package +Do you observe a noticable performance increase on the ImageNet inference workload? diff --git a/vta/examples/resnet18/pynq/README.md b/vta/examples/resnet18/pynq/README.md index 1906ca082378e..1213d94ec6b1f 100644 --- a/vta/examples/resnet18/pynq/README.md +++ b/vta/examples/resnet18/pynq/README.md @@ -1,127 +1,6 @@ # Resnet-18 Example on Pynq-based VTA Design -In order to run this example you'll need to have: -* VTA installed -* LLVM 4.0 or newer installed -* TVM installed -* NNVM installed -* MxNet installed -* A Pynq-based RPC server running -* Python packages installed - -Required setup time from scratch: ~15 mins. - -## VTA installation - -Clone the VTA repository in the directory of your choosing: -```bash -git clone git@github.com:uwsaml/vta.git --recursive -``` - -Update your `~/.bashrc` file to include the VTA python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!): -```bash -export PYTHONPATH=/python:${PYTHONPATH} -``` - -## LLVM installation - -We provide the set of commands to install LLVM 6.0 (stable branch) on Ubuntu Xenial. Note that the [LLVM installation process](apt.llvm.org) can be adapted to different LLVM branches, and operating systems/distros. - -```bash -wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - -sudo apt-add-repository "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main” -sudo apt-get update -apt-get install clang-6.0 lldb-6.0 lld-6.0 -``` - -To ensure that LLVM 6.0 was properly installed, check that the following command gives the path to your `llvm-config` binary. - -```bash -which llvm-config-6.0 -``` - -## TVM installation - -Clone the TVM repository in the directory of your choosing: -```bash -git clone git@github.com:dmlc/tvm.git --recursive -``` - -TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints. -As of now, the TVM checkpoint `168f099155106d1188dbc54ac00acc02900a3c6f` is known to work with VTA. -```bash -cd -git checkout 168f099155106d1188dbc54ac00acc02900a3c6f -``` - -Before building TVM, copy the `make/config.mk` file into the root TVM directory: -```bash -cd -cp make/config.mk . -``` - -In the 'config.mk' file sure that: -* `LLVM_CONFIG` points to the `llvm-config` executable which path was derived in the TVM installation instructions above (e.g. `LLVM_CONFIG = /usr/bin/llvm-config-6.0`) -* `USE_RPC` should be set to 1 - -Launch the compilation, this takes about 5-10 minutes on two threads. -```bash -cd -make -j2 -``` - -Finally update your `~/.bashrc` file to include the TVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!): -```bash -export PYTHONPATH=/python:/topi/python:${PYTHONPATH} -``` - -## NNVM installation - -Clone the NNVM repository from `tqchen` in the directory of your choosing: -```bash -git clone git@github.com:tqchen/nnvm.git --recursive -``` - -To run this example, we rely on a special branch of NNVM `qt`: -```bash -cd -git checkout qt -``` - -Launch the compilation, this takes about a minute on two threads. -```bash -cd -make -j2 -``` - -Finally update your `~/.bashrc` file to include the NNVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!): -```bash -export PYTHONPATH=/python:${PYTHONPATH} -``` - -## MxNet Installation - -Follow the [MxNet Installation Instructions](https://mxnet.incubator.apache.org) - -## Pynq RPC Server Setup - -Follow the [Pynq RPC Server Guide](https://github.com/uwsaml/vta/tree/master/apps/pynq_rpc/README.md) - -## Python packages - -You'll need the following packages to be installed for the example to run properly. You can use `pip` to install those packages: -* `decorator` (for TVM) -* `enum34` (for NNVM) -* `Pillow` -* `wget` - -## Running the example - -Configure your environment with the following: -```bash -export VTA_PYNQ_RPC_HOST=192.168.2.99 -export VTA_PYNQ_RPC_PORT=9091 -``` +Follow the first two parts of the [Installation Guide](../../../docs/how_to/install.md) to make sure that the VTA python libraries are installed, and that the RPC server is running on the Pynq FPGA dev board. Simply run the following python script: ```bash diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index 7391d56e27744..3f0717f5b08b4 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -43,7 +43,6 @@ class DevContext(object): QID_LOAD_OUT = 2 QID_STORE_OUT = 3 QID_COMPUTE = 2 - QID_STORE_INP = 3 def __init__(self, env): self.vta_axis = tvm.thread_axis("vta") @@ -157,6 +156,7 @@ def __init__(self, cfg): self.acc_dtype = "int%d" % self.ACC_WIDTH self.inp_dtype = "int%d" % self.INP_WIDTH self.wgt_dtype = "int%d" % self.WGT_WIDTH + self.out_dtype = "int%d" % self.OUT_WIDTH # lazy cached members self.mock_mode = False self._mock_env = None diff --git a/vta/tutorials/get_started.py b/vta/tutorials/get_started.py index 73839196145d7..d3f3348754693 100644 --- a/vta/tutorials/get_started.py +++ b/vta/tutorials/get_started.py @@ -1,39 +1,387 @@ """ +.. _get-started: + Get Started with VTA ==================== -**Author**: `Tianqi Chen `_ +**Author**: `Thierry Moreau `_ -This is an introduction tutorial to on how to use TVM to program VTA +This is an introduction tutorial on how to use TVM to program the VTA design. -In this tutorial, we will demonstrate the basic workflow of VTA -and how we can program the FPGA to run various instructions. +In this tutorial, we will demonstrate the basic TVM workflow to implement +a vector addition on the VTA design's vector ALU. +This process includes specific scheduling transformations necessary to lower +computation down to low-level accelerator operations. -To begin with, we need to import tvm which is our compiler stack for VTA. -We also need to import vta python package which contains VTA specific -extensions for compiler to generate code that runs on VTA. +To begin, we need to import TVM which is our deep learning optimizing compiler. +We also need to import the VTA python package which contains VTA specific +extensions for TVM to target the VTA design. """ from __future__ import absolute_import, print_function +import os +import tvm import vta +import numpy as np + +###################################################################### +# Loading in VTA Parameters +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# VTA is a modular and customizable design. Consequently, the user +# is free to modify high-level hardware parameters that affect +# the hardware design layout. +# These parameters are specified in the :code:`config.json` file by their +# :code:`log2` values. +# These VTA parameters can be loaded with the :code:`vta.get_env` +# function. +# +# Finally, the TVM target is specified in the :code:`config.json` file. +# When set to *sim*, execution will take place inside of a behavioral +# VTA simulator. +# If you want to run this tutorial on the Pynq FPGA development platform, +# follow the *VTA Pynq-Based Testing Setup* guide. + +env = vta.get_env() + +###################################################################### +# FPGA Programming +# ---------------- +# When targeting the Pynq FPGA development board, we need to configure +# the board with a VTA bitstream. + +# We'll need the TVM RPC module and the VTA simulator module +from tvm.contrib import rpc, util +from vta.testing import simulator + +# We read the Pynq RPC host IP address and port number from the OS environment +host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") +port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) + +# We configure both the bitstream and the runtime system on the Pynq +# to match the VTA configuration specified by the config.json file. +if env.TARGET == "pynq": + + # Make sure that TVM was compiled with RPC=1 + assert tvm.module.enabled("rpc") + remote = rpc.connect(host, port) + + # Reconfigure the JIT runtime + vta.reconfig_runtime(remote) + + # Program the FPGA with a pre-compiled VTA bitstream. + # You can program the FPGA with your own custom bitstream + # by passing the path to the bitstream file instead of None. + vta.program_fpga(remote, bitstream=None) + +# In simulation mode, host the RPC server locally. +elif env.TARGET == "sim": + remote = rpc.LocalSession() + +###################################################################### +# Computation Declaration +# ----------------------- +# As a first step, we need to describe our computation. +# TVM adopts tensor semantics, with each intermediate result +# represented as multi-dimensional array. The user needs to describe +# the computation rule that generates the output tensors. +# +# In this example we describe a vector addition, which requires multiple +# computation stages, as shown in the dataflow diagram below. +# First we describe the input tensors :code:`A` and :code:`B` that are living +# in main memory. +# Second, we need to declare intermediate tensors :code:`A_buf` and +# :code:`B_buf`, which will live in VTA's on-chip buffers. +# Having this extra computational stage allows us to explicitly +# stage cached reads and writes. +# Third, we describe the vector addition computation which will +# add :code:`A_buf` to :code:`B_buf` to produce :code:`C_buf`. +# The last operation is a cast and copy back to DRAM, into results tensor +# :code:`C`. +# +# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/vadd_dataflow.png +# :align: center + +###################################################################### +# Input Placeholders +# ~~~~~~~~~~~~~~~~~~ +# We describe the placeholder tensors :code:`A`, and :code:`B` in a tiled data +# format to match the data layout requirements imposed by the VTA vector ALU. +# +# For VTA's general purpose operations such as vector adds, the tile size is +# :code:`(env.BATCH, env.BLOCK_OUT)`. +# The dimensions are specified in +# the :code:`config.json` configuration file and are set by default to +# a (1, 16) vector. +# +# In addition, A and B's data types also needs to match the :code:`env.acc_dtype` +# which is set by the :code:`config.json` file to be a 32-bit integer. + +# Output channel factor m - total 64 x 16 = 1024 output channels +m = 64 +# Batch factor o - total 1 x 1 = 1 +o = 1 +# A placeholder tensor in tiled data format +A = tvm.placeholder((o, m, env.BATCH, env.BLOCK_OUT), name="A", dtype=env.acc_dtype) +# B placeholder tensor in tiled data format +B = tvm.placeholder((o, m, env.BATCH, env.BLOCK_OUT), name="B", dtype=env.acc_dtype) + +###################################################################### +# Copy Buffers +# ~~~~~~~~~~~~ +# One specificity of hardware accelerators, is that on-chip memory has to be +# explicitly managed. +# This means that we'll need to describe intermediate tensors :code:`A_buf` +# and :code:`B_buf` that can have a different memory scope than the original +# placeholder tensors :code:`A` and :code:`B`. +# +# Later in the scheduling phase, we can tell the compiler that :code:`A_buf` +# and :code:`B_buf` will live in the VTA's on-chip buffers (SRAM), while +# :code:`A` and :code:`B` will live in main memory (DRAM). +# We describe A_buf and B_buf as the result of a compute +# operation that is the identity function. +# This can later be interpreted by the compiler as a cached read operation. + +# A copy buffer +A_buf = tvm.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: A(*i), "A_buf") +# B copy buffer +B_buf = tvm.compute((o, m, env.BATCH, env.BLOCK_OUT), lambda *i: B(*i), "B_buf") + +###################################################################### +# Vector Addition +# ~~~~~~~~~~~~~~~ +# Now we're ready to describe the vector addition result tensor :code:`C`, +# with another compute operation. +# The compute function takes the shape of the tensor, as well as a lambda +# function that describes the computation rule for each position of the tensor. +# +# No computation happens during this phase, as we are only declaring how +# the computation should be done. + +# Describe the in-VTA vector addition +C_buf = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: A_buf(*i).astype(env.acc_dtype) + B_buf(*i).astype(env.acc_dtype), + name="C_buf") + +###################################################################### +# Casting the Results +# ~~~~~~~~~~~~~~~~~~~ +# After the computation is done, we'll need to send the results computed by VTA +# back to main memory. + +###################################################################### +# .. note:: +# +# **Memory Store Restrictions** +# +# One specificity of VTA is that it only supports DRAM stores in the narrow +# :code:`env.inp_dtype` data type format. +# This lets us reduce the data footprint for memory transfers (more on this +# in the basic matrix multiply example). +# +# We perform one last typecast operation to the narrow +# input activation data format. + +# Cast to output type, and send to main memory +C = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: C_buf(*i).astype(env.inp_dtype), + name="C") + +###################################################################### +# This concludes the computation declaration part of this tutorial. ###################################################################### -# Program the FPGA with VTA Bistream -# ---------------------------------- -# In the first step, we need to program the FPGA with VTA bitstream. +# Scheduling the Computation +# -------------------------- +# While the above lines describes the computation rule, we can obtain +# :code:`C` in many ways. +# TVM asks the user to provide an implementation of the computation called +# *schedule*. # +# A schedule is a set of transformations to an original computation that +# transforms the implementation of the computation without affecting +# correctness. +# This simple VTA programming tutorial aims to demonstrate basic schedule +# transformations that will map the original schedule down to VTA hardware +# primitives. + + +###################################################################### +# Default Schedule +# ~~~~~~~~~~~~~~~~ +# After we construct the schedule, by default the schedule computes +# :code:`C` in the following way: + +# Let's take a look at the generated schedule +s = tvm.create_schedule(C.op) + +print(tvm.lower(s, [A, B, C], simple_mode=True)) ###################################################################### -# Run Simple Copy Instruction -# --------------------------- +# Although this schedule makes sense, it won't compile to VTA. +# In order to obtain correct code generation, we need to apply scheduling +# primitives and code annotation that will transform the schedule into +# one that can be directly lowered onto VTA hardware intrinsics. +# Those include: # +# - DMA copy operations which will take globally-scoped tensors and copy +# those into locally-scoped tensors. +# - Vector ALU operations that will perform the vector add. ###################################################################### -# Run Matrix Instruction -# ---------------------- +# Buffer Scopes +# ~~~~~~~~~~~~~ +# First, we set the scope of the copy buffers to indicate to TVM that these +# intermediate tensors will be stored in the VTA's on-chip SRAM buffers. +# Below, we tell TVM that :code:`A_buf`, :code:`B_buf`, :code:`C_buf` +# will live in VTA's on-chip *accumulator buffer* which serves as +# VTA's general purpose register file. # +# Set the intermediate tensors' scope to VTA's on-chip accumulator buffer +s[A_buf].set_scope(env.acc_scope) +s[B_buf].set_scope(env.acc_scope) +s[C_buf].set_scope(env.acc_scope) ###################################################################### -# Matrix Multiplication Example -# ----------------------------- +# DMA Transfers +# ~~~~~~~~~~~~~ +# We need to schedule DMA transfers to move data living in DRAM to +# and from the VTA on-chip buffers. +# We insert :code:`dma_copy` pragmas to indicate to the compiler +# that the copy operations will be performed in bulk via DMA, +# which is common in hardware accelerators. + +# Tag the buffer copies with the DMA pragma to map a copy loop to a +# DMA transfer operation +s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy) +s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy) +s[C].pragma(s[C].op.axis[0], env.dma_copy) + +###################################################################### +# ALU Operations +# ~~~~~~~~~~~~~~ +# VTA has a vector ALU that can perform vector operations on tensors +# in the accumulator buffer. +# In order to tell TVM that a given operation needs to be mapped to the +# VTA's vector ALU, we need to explicitly tag the vector addition loop +# with an :code:`env.alu` pragma. + +# Tell TVM that the computation needs to be performed +# on VTA's vector ALU +s[C_buf].pragma(C_buf.op.axis[0], env.alu) + +# Let's take a look at the finalized schedule +print(vta.lower(s, [A, B, C], simple_mode=True)) + +###################################################################### +# This concludes the scheduling portion of this tutorial. + +###################################################################### +# TVM Compilation +# --------------- +# After we have finished specifying the schedule, we can compile it +# into a TVM function. By default TVM compiles into a type-erased +# function that can be directly called from python side. # +# In the following line, we use :code:`tvm.build` to create a function. +# The build function takes the schedule, the desired signature of the +# function(including the inputs and outputs) as well as target language +# we want to compile to. +# +my_vadd = vta.build(s, [A, B, C], "ext_dev", env.target_host, name="my_vadd") + +###################################################################### +# Saving the Module +# ~~~~~~~~~~~~~~~~~ +# TVM lets us save our module into a file so it can loaded back later. This +# is called ahead-of-time compilation and allows us to save some compilation +# time. +# More importantly, this allows us to cross-compile the executable on our +# development machine and send it over to the Pynq FPGA board over RPC for +# execution. + +# Write the compiled module into an object file. +temp = util.tempdir() +my_vadd.save(temp.relpath("vadd.o")) + +# Send the executable over RPC +remote.upload(temp.relpath("vadd.o")) + +###################################################################### +# Loading the Module +# ~~~~~~~~~~~~~~~~~~ +# We can load the compiled module from the file system to run the code. + +f = remote.load_module("vadd.o") + +###################################################################### +# Running the Function +# -------------------- +# The compiled TVM function uses a concise C API and can be invoked from +# any language. +# +# TVM provides an array API in python to aid quick testing and prototyping. +# The array API is based on `DLPack `_ standard. +# +# - We first create a remote context (for remote execution on the Pynq). +# - Then :code:`tvm.nd.array` formats the data accordingly. +# - :code:`f()` runs the actual computation. +# - :code:`asnumpy()` copies the result array back in a format that can be +# interpreted. +# + +# Get the remote device context +ctx = remote.ext_dev(0) + +# Initialize the A and B arrays randomly in the int range of (-128, 128] +A_orig = np.random.randint( + -128, 128, size=(o * env.BATCH, m * env.BLOCK_OUT)).astype(A.dtype) +B_orig = np.random.randint( + -128, 128, size=(o * env.BATCH, m * env.BLOCK_OUT)).astype(B.dtype) + +# Apply packing to the A and B arrays from a 2D to a 4D packed layout +A_packed = A_orig.reshape( + o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3)) +B_packed = B_orig.reshape( + o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3)) + +# Format the input/output arrays with tvm.nd.array to the DLPack standard +A_nd = tvm.nd.array(A_packed, ctx) +B_nd = tvm.nd.array(B_packed, ctx) +C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx) + +# Invoke the module to perform the computation +f(A_nd, B_nd, C_nd) + +###################################################################### +# Verifying Correctness +# --------------------- +# Compute the reference result with numpy and assert that the output of the +# matrix multiplication indeed is correct + +# Compute reference result with numpy +C_ref = (A_orig.astype(env.acc_dtype) + B_orig.astype(env.acc_dtype)).astype(C.dtype) +C_ref = C_ref.reshape( + o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3)) +np.testing.assert_equal(C_ref, C_nd.asnumpy()) +print("Successful vector add test!") + +###################################################################### +# Summary +# ------- +# This tutorial provides a walk-through of TVM for programming the +# deep learning accelerator VTA with a simple vector addition example. +# The general workflow includes: +# +# - Programming the FPGA with the VTA bitstream over RPC. +# - Describing the vector add computation via a series of computations. +# - Describing how we want to perform the computation using schedule primitives. +# - Compiling the function to the VTA target. +# - Running the compiled module and verifying it against a numpy implementation. +# +# You are more than welcome to check other examples out and tutorials +# to learn more about the supported operations, schedule primitives +# and other features supported by TVM to program VTA. +# + diff --git a/vta/tutorials/matrix_multiply.py b/vta/tutorials/matrix_multiply.py new file mode 100644 index 0000000000000..f9d4d9fc7094b --- /dev/null +++ b/vta/tutorials/matrix_multiply.py @@ -0,0 +1,453 @@ +""" +.. _basic-mat-mult: + +Simple Matrix Multiply +====================== +**Author**: `Thierry Moreau `_ + +In this tutorial, we will build on top of the :ref:`get-started` tutorial +and introduce additional concepts required to implement matrix multiplication +on VTA with the TVM workflow. +""" + +###################################################################### +# RPC Setup +# --------- +# We start by programming the Pynq's FPGA and building its RPC runtime +# as we did in the VTA introductory tutorial. + +from __future__ import absolute_import, print_function + +import os +import tvm +import vta +import numpy as np +from tvm.contrib import rpc, util +from vta.testing import simulator + +# Load VTA parameters from the config.json file +env = vta.get_env() + +# We read the Pynq RPC host IP address and port number from the OS environment +host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") +port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) + +# We configure both the bitstream and the runtime system on the Pynq +# to match the VTA configuration specified by the config.json file. +if env.TARGET == "pynq": + + # Make sure that TVM was compiled with RPC=1 + assert tvm.module.enabled("rpc") + remote = rpc.connect(host, port) + + # Reconfigure the JIT runtime + vta.reconfig_runtime(remote) + + # Program the FPGA with a pre-compiled VTA bitstream. + # You can program the FPGA with your own custom bitstream + # by passing the path to the bitstream file instead of None. + vta.program_fpga(remote, bitstream=None) + +# In simulation mode, host the RPC server locally. +elif env.TARGET == "sim": + remote = rpc.LocalSession() + +###################################################################### +# Computation Declaration +# ----------------------- +# In this example we describe a simple matrix multiplication addition, which +# requires multiple computation stages, as shown in the dataflow diagram below. +# First we describe the input tensors :code:`A` and :code:`B` that are living +# in main memory. +# Second, we need to declare intermediate tensors :code:`A_buf` and +# :code:`B_buf`, which will live in VTA's on-chip buffers. +# Having this extra computational stage allows us to explicitly +# stage cached reads and writes. +# Third, we describe the matrix multiplication computation over +# :code:`A_buf` and :code:`B_buf` to produce the product matrix :code:`C_buf`. +# The last operation is a cast and copy back to DRAM, into results tensor +# :code:`C`. +# +# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/gemm_dataflow.png +# :align: center + +###################################################################### +# Data Layout +# ~~~~~~~~~~~ +# We describe the placeholder tensors :code:`A`, and :code:`B` in a tiled data +# format to match the data layout requirements imposed by the VTA tensor core. + +###################################################################### +# .. note:: +# +# **Data Tiling** +# +# One source of complexity when targeting accelerators is to make sure +# that the data layout matches the layout imposed by the accelerator design. +# VTA is designed around a *tensor core* that performs, one matrix-matrix +# operation per cycle between an activation matrix and a weight matrix, +# adding the result matrix to an accumulator matrix, as shown in the +# figure below. +# +# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/tensor_core.png +# :align: center +# +# The dimensions of that matrix-matrix multiplication are specified in +# the :code:`config.json` configuration file. +# The activation matrix has a :code:`(BATCH, BLOCK_IN)` shape +# and the transposed weight matrix has a :code:`(BLOCK_OUT, BLOCK_IN)` shape, +# thus inferring that the resulting output matrix has a +# :code:`(BATCH, BLOCK_OUT)` shape. +# Consequently input and output tensors processed by VTA need to be +# tiled according to these aforementioned dimension. +# +# The diagram below shows the impact of data tiling on a matrix that is +# originally of shape (4, 8). +# Tiling by a (2, 2) tile shape ensures that data within each tile is +# contiguous. +# The resulting tiled tensor has a shape of (2, 4, 2, 2). +# +# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/data_tiling.png +# :align: center +# +# We first define the variables :code:`m`, :code:`n`, :code:`o` to represent +# the shape of the matrix multiplication. These variables are multiplicative +# factors over the :code:`BLOCK_OUT`, :code:`BLOCK_IN`, and :code:`BATCH` +# tensor dimensions respectively. +# By default, the configuration file sets :code:`BATCH`, :code:`BLOCK_IN`, and +# :code:`BLOCK_OUT` to be 1, 16 and 16 respectively (:code:`BATCH` being set to +# 1 implies that our compute building block is vector-matrix multiply). +# + + +###################################################################### +# .. note:: +# +# **Data Types** +# +# It's important to not only match the inner-tile +# dimension of VTA's tensor core, but also to match the specific data types +# expected by VTA. +# VTA for now only supports fixed point data types, which integer width is +# specified in the :code:`config.json` file by :code:`INP_WIDTH` and +# :code:`WGT_WIDTH` for the activations and weights data types respectively. +# In addition, the accumulator data type integer width is specified by +# :code:`ACC_WIDTH`. +# +# By default, the configuration file sets :code:`INP_WIDTH` +# and :code:`WGT_WIDTH` to 8. +# The accumulator width :code:`ACC_WIDTH` is set to 32, in order to avoid +# overflow during accumulation. +# As a result, :code:`env.inp_dtype` and :code:`env.wgt_dtype` are all +# narrow 8-bit integers, while :code:`env.acc_dtype` is a standard 32-bit +# integer. + +# Output channel factor m - total 16x16=256 output channels +m = 16 +# Input channel factor n - total 16x16=256 input channels +n = 16 +# Batch factor o (we use single batch inference) +o = 1 +# A placeholder tensor in tiled data format +A = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="A", dtype=env.inp_dtype) +# B placeholder tensor in tiled data format +B = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="B", dtype=env.wgt_dtype) +# A copy buffer +A_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: A(*i), "A_buf") +# B copy buffer +B_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: B(*i), "B_buf") + +###################################################################### +# Matrix Multiplication +# ~~~~~~~~~~~~~~~~~~~~~ +# Now we're ready to describe the matrix multiplication result tensor :code:`C`, +# with another compute operation. +# The compute function takes the shape of the tensor, as well as a lambda +# function that describes the computation rule for each position of the tensor. +# +# In order to implement matrix multiplication, the lambda function needs to +# include a reduction formula over the input channel dimension axes. +# To create a reduction formula, we can declare a reduction axis using +# :code:`tvm.reduce_axis`, which takes in the range of reductions. +# :code:`tvm.sum` takes in the expression to be reduced as well as +# the reduction axes to compute the sum of value over all k in the declared +# ranges. +# +# Note that the reduction needs to be performed over 32-bit :code:`env.acc_dtype` +# accumulator data types. +# +# No computation happens during this phase, as we are only declaring how +# the computation should be done. + +# Outer input feature reduction axis +ko = tvm.reduce_axis((0, n), name="ko") +# Inner input feature reduction axis +ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki") +# Describe the in-VTA matrix multiplication +C_buf = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda bo, co, bi, ci: + tvm.sum(A_buf[bo, ko, bi, ki].astype(env.acc_dtype) * + B_buf[co, ko, ci, ki].astype(env.acc_dtype), + axis=[ko, ki]), + name="C_buf") + +###################################################################### +# Casting the Results +# ~~~~~~~~~~~~~~~~~~~ +# After the computation is done, we'll need to send the results computed by VTA +# back to main memory. + +###################################################################### +# .. note:: +# +# **Memory Store Restrictions** +# +# One specificity of VTA is that it only supports DRAM stores in the narrow +# :code:`env.inp_dtype` data type format. +# This lets us reduce the data footprint for memory transfers, but also lets +# us quantize the wide accumulator data type down to a data format that +# matches the input activation data type. +# This means that in the context of neural network inference, the outputs +# of a given layer after activation can be consumed directly by the next +# layer. +# +# We perform one last typecast operation to the narrow +# input activation data format. + +# Cast to output type, and send to main memory +C = tvm.compute( + (o, m, env.BATCH, env.BLOCK_OUT), + lambda *i: C_buf(*i).astype(env.inp_dtype), + name="C") + +###################################################################### +# This concludes the computation declaration part of this tutorial. + +###################################################################### +# Scheduling the Computation +# -------------------------- +# While the above lines describes the computation rule, we can obtain +# :code:`C` in many ways. +# TVM asks the user to provide an implementation of the computation called +# *schedule*. +# +# A schedule is a set of transformations to an original computation that +# transforms the implementation of the computation without affecting +# correctness. +# This simple VTA programming tutorial aims to demonstrate basic schedule +# transformations that will map the original schedule down to VTA hardware +# primitives. + + +###################################################################### +# Default Schedule +# ~~~~~~~~~~~~~~~~ +# After we construct the schedule, by default the schedule computes +# :code:`C` in the following way: + +# Let's take a look at the generated schedule +s = tvm.create_schedule(C.op) +print(tvm.lower(s, [A, B, C], simple_mode=True)) + +###################################################################### +# Although this schedule makes sense, it won't compile to VTA. +# In order to obtain correct code generation, we need to apply scheduling +# primitives and code annotation that will transform the schedule into +# one that can be directly lowered onto VTA hardware intrinsics. +# Those include: +# +# - DMA copy operations which will take globally-scoped tensors and copy +# those into locally-scoped tensors. +# - Tensor operations that will perform the matrix multiplication. + +###################################################################### +# Buffer Scopes +# ~~~~~~~~~~~~~ +# First, we set the scope of the buffers to tell TVM that these buffers +# will be living in the VTA's on-chip SRAM caches. +# Below, we tell TVM that :code:`A_buf`, :code:`B_buf`, :code:`C_buf` +# will respectively live in VTA's on-chip input, weight and accumulator +# memory. + +###################################################################### +# .. note:: +# +# **VTA's On-Chip SRAMs** +# +# VTA has three different memory scopes, each corresponding to different +# on-chip SRAM buffers. +# +# - :code:`env.inp_scope`: Input buffer, which is a read-only SRAM buffer +# that stores input matrices of shape :code:`(env.BATCH, env.BLOCK_IN)` +# of type :code:`env.inp_dtype`. The input buffer contains +# `2 ^ LOG_INP_BUFF_SIZE` matrix elements (as specified in the +# :code:`config.json` file). +# - :code:`env.wgt_scope`: Weight buffer, which is a read-only SRAM buffer +# that stores weight matrices of shape :code:`(env.BLOCK_OUT, env.BLOCK_IN)` +# of type :code:`env.wgt_dtype`. The weight buffer contains +# `2 ^ LOG_WGT_BUFF_SIZE` matrix elements. +# - :code:`env.acc_scope`: Accumulator buffer, which is a read/write SRAM +# buffer that stores accumulator matrices of shape +# :code:`(env.BATCH, env.BLOCK_OUT)` of type :code:`env.acc_dtype`. +# The accumulator buffer is VTA's general purpose register file: it holds +# both intermediate results of convolutions and matrix multiplications +# as well as intermediate results of pooling, batch normalization, and +# activation layers. The accumulator buffer contains +# `2 ^ LOG_ACC_BUFF_SIZE` matrix elements. + +# Set the intermediate tensor's scope to VTA's on-chip buffers +s[A_buf].set_scope(env.inp_scope) +s[B_buf].set_scope(env.wgt_scope) +s[C_buf].set_scope(env.acc_scope) + +###################################################################### +# DMA Transfers +# ~~~~~~~~~~~~~ +# We need to schedule DMA transfers to move data living in DRAM to +# and from the VTA on-chip buffers. +# This can be achieved using the :code:`compute_at` schedule primitive +# which nests the copying of the buffers into the computation loop +# that performs the matrix multiplication. +# +# We insert :code:`dma_copy` pragmas to indicate to the compiler +# that the copy operations will be performed in bulk via DMA, +# which is common in hardware accelerators. +# Finally, we print the temporary schedule to observe the effects of +# moving the copy operations into the matrix multiplication loop. + +# Move buffer copy into matrix multiply loop +s[A_buf].compute_at(s[C_buf], ko) +s[B_buf].compute_at(s[C_buf], ko) + +# Tag the buffer copies with the DMA pragma to insert a DMA transfer +s[A_buf].pragma(s[A_buf].op.axis[0], env.dma_copy) +s[B_buf].pragma(s[B_buf].op.axis[0], env.dma_copy) +s[C].pragma(s[C].op.axis[0], env.dma_copy) + +# Let's take a look at the transformed schedule +print(tvm.lower(s, [A, B, C], simple_mode=True)) + +###################################################################### +# Tensorization +# ~~~~~~~~~~~~~ +# The last step of the schedule transformation consists in applying +# *tensorization* to our schedule. +# Tensorization is analogous to vectorization, but extends the concept +# to a higher-dimensional unit of computation. +# Consequently, tensorization imposes data layout constraints as discussed +# earlier when declaring the data layout input placeholders. +# We've already arranged our tensors in a tiled format, so the next thing +# we need to perform is loop reordering to accommodate for tensorization. +# +# Here we choose to move the outermost reduction axis all the way out. +# This dictates that we first iterate over input channels, then batch +# dimensions, and finally output channels. +# Lastly, we apply the tensorization scheduling primitive :code:`tensorize` +# along the outer axis of the inner-most matrix matrix multiplication tensor +# block. +# We print the finalized schedule that is ready for code-generation +# by the VTA runtime JIT compiler. + +s[C_buf].reorder( + ko, + s[C_buf].op.axis[0], + s[C_buf].op.axis[1], + s[C_buf].op.axis[2], + s[C_buf].op.axis[3], + ki) +s[C_buf].tensorize(s[C_buf].op.axis[2], env.gemm) + +# Let's take a look at the finalized schedule +print(vta.lower(s, [A, B, C], simple_mode=True)) + +###################################################################### +# This concludes the scheduling portion of this tutorial. + +###################################################################### +# TVM Compilation +# --------------- +# After we have finished specifying the schedule, we can compile it +# into a TVM function. + +# Build GEMM VTA kernel +my_gemm = vta.build(s, [A, B, C], "ext_dev", env.target_host, name="my_gemm") + +# Write the compiled module into an object file. +temp = util.tempdir() +my_gemm.save(temp.relpath("gemm.o")) + +# Send the executable over RPC +remote.upload(temp.relpath("gemm.o")) + +# Load the compiled module +f = remote.load_module("gemm.o") + +###################################################################### +# Running the Function +# -------------------- +# The compiled TVM function uses a concise C API and can be invoked from +# code language. +# +# TVM provides an array API in python to aid quick testing and prototyping. +# The array API is based on `DLPack `_ standard. +# +# - We first create a remote context (for remote execution on the Pynq). +# - Then :code:`tvm.nd.array` formats the data accordingly. +# - :code:`f()` runs the actual computation. +# - :code:`asnumpy()` copies the result array back in a format that can be +# interpreted. +# + +# Get the remote device context +ctx = remote.ext_dev(0) + +# Initialize the A and B arrays randomly in the int range of (-128, 128] +A_orig = np.random.randint( + -128, 128, size=(o * env.BATCH, n * env.BLOCK_IN)).astype(A.dtype) +B_orig = np.random.randint( + -128, 128, size=(m * env.BLOCK_OUT, n * env.BLOCK_IN)).astype(B.dtype) + +# Apply packing to the A and B arrays from a 2D to a 4D packed layout +A_packed = A_orig.reshape( + o, env.BATCH, n, env.BLOCK_IN).transpose((0, 2, 1, 3)) +B_packed = B_orig.reshape( + m, env.BLOCK_OUT, n, env.BLOCK_IN).transpose((0, 2, 1, 3)) + +# Format the input/output arrays with tvm.nd.array to the DLPack standard +A_nd = tvm.nd.array(A_packed, ctx) +B_nd = tvm.nd.array(B_packed, ctx) +C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx) + +# Invoke the module to perform the computation +f(A_nd, B_nd, C_nd) + +###################################################################### +# Verifying Correctness +# --------------------- +# Compute the reference result with numpy and assert that the output of the +# matrix multiplication indeed is correct + +# Compute reference result with numpy +C_ref = np.dot(A_orig.astype(env.acc_dtype), + B_orig.T.astype(env.acc_dtype)).astype(C.dtype) +C_ref = C_ref.reshape( + o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3)) +np.testing.assert_equal(C_ref, C_nd.asnumpy()) +print("Successful matrix multiply test!") + + +###################################################################### +# Summary +# ------- +# This tutorial showcases the TVM workflow to implement a simple matrix +# multiplication example on VTA. +# The general workflow includes: +# +# - Programming the FPGA with the VTA bitstream over RPC. +# - Describing matrix multiplication via a series of computations. +# - Describing how we want to perform the computation using schedule primitives. +# - Compiling the function to the VTA target. +# - Running the compiled module and verifying it against a numpy implementation. +# + diff --git a/vta/tutorials/matrix_multiply_opt.py b/vta/tutorials/matrix_multiply_opt.py new file mode 100644 index 0000000000000..9b62504a70078 --- /dev/null +++ b/vta/tutorials/matrix_multiply_opt.py @@ -0,0 +1,362 @@ +""" +.. _mat-mult-opt: + +Matrix Multiply Blocking +======================== +**Author**: `Thierry Moreau `_ + +This tutorial provides an overview on how to use TVM to map matrix +multiplication efficiently on the VTA design. +We recommend covering the :ref:`basic-mat-mult` tutorial first. + +In this tutorial, we will demonstrate TVM schedule optimizations to break large +neural network operators down onto smaller blocks to achieve computation within +limited hardware accelerator resources. +""" + +###################################################################### +# RPC Setup +# --------- +# We start by programming the Pynq's FPGA and building its RPC runtime. + +from __future__ import absolute_import, print_function + +import os +import tvm +import vta +import numpy as np +from tvm.contrib import rpc, util +from vta.testing import simulator + +# Load VTA parameters from the config.json file +env = vta.get_env() + +# We read the Pynq RPC host IP address and port number from the OS environment +host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") +port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) + +# We configure both the bitstream and the runtime system on the Pynq +# to match the VTA configuration specified by the config.json file. +if env.TARGET == "pynq": + + # Make sure that TVM was compiled with RPC=1 + assert tvm.module.enabled("rpc") + remote = rpc.connect(host, port) + + # Reconfigure the JIT runtime + vta.reconfig_runtime(remote) + + # Program the FPGA with a pre-compiled VTA bitstream. + # You can program the FPGA with your own custom bitstream + # by passing the path to the bitstream file instead of None. + vta.program_fpga(remote, bitstream=None) + +# In simulation mode, host the RPC server locally. +elif env.TARGET == "sim": + remote = rpc.LocalSession() + +###################################################################### +# Computation Declaration +# ----------------------- +# As a first step, we need to describe our matrix multiplication computation. +# We define the matrix multiplication as the computation one would find in a +# fully connected layer, defined by its batch size, input channels, and output +# channels. +# These have to be integer multiples of the VTA tensor shape: +# :code:`BATCH`, :code:`BLOCK_IN`, and :code:`BLOCK_OUT` respectively. +# +# We've added extra operators to the matrix multiplication that apply +# shifting and clipping to the output in order to mimic the a fixed-point +# matrix multiplication followed by a rectified linear activation. +# We describe the TVM dataflow graph of the fully connected layer below: +# +# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/fc_dataflow.png +# :align: center +# +# This computation is intentionally too large to fit onto VTA's on-chip +# buffers all at once. Therefore in the scheduling phase we'll +# rely on computation blocking strategies to break the computation down into +# manageable chunks. + +# Fully connected layer dimensions: 1024 x 1024 +batch_size = 1 +in_channels = 1024 +out_channels = 1024 +assert batch_size % env.BATCH == 0 +assert in_channels % env.BLOCK_IN == 0 +assert out_channels % env.BLOCK_OUT == 0 + +# Let's derive the tiled input tensor shapes +data_shape = (batch_size // env.BATCH, + in_channels // env.BLOCK_IN, + env.BATCH, + env.BLOCK_IN) +weight_shape = (out_channels // env.BLOCK_OUT, + in_channels // env.BLOCK_IN, + env.BLOCK_OUT, + env.BLOCK_IN) +output_shape = (batch_size // env.BATCH, + out_channels // env.BLOCK_OUT, + env.BATCH, + env.BLOCK_OUT) +num_ops = in_channels * out_channels * batch_size * 2 + +# Reduction axes +ic = tvm.reduce_axis((0, in_channels // env.BLOCK_IN), name='ic') +ic_tns = tvm.reduce_axis((0, env.BLOCK_IN), name='ic_tns') + +# Input placeholder tensors +data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) +weight = tvm.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype) + +# Copy buffers +data_buf = tvm.compute(data_shape, + lambda *i: data(*i), + "data_buf") +weight_buf = tvm.compute(weight_shape, + lambda *i: weight(*i), + "weight_buf") + +# Declare matrix multiply computation +res_gemm = tvm.compute(output_shape, + lambda bo, co, bi, ci: tvm.sum( + data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype) * + weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype), + axis=[ic, ic_tns]), + name="res_gem") + +# Add shift stage for fix-point normalization +res_shr = tvm.compute(output_shape, + lambda *i: res_gemm(*i) >> env.INP_WIDTH, + name="res_shr") + +# Apply clipping between (0, input max value) +inp_max = (1<<(env.INP_WIDTH-1))-1 +res_max = tvm.compute(output_shape, + lambda *i: tvm.max(res_shr(*i), 0), + "res_max") +res_min = tvm.compute(output_shape, + lambda *i: tvm.min(res_max(*i), inp_max), + "res_min") + +# Apply typecast to input data type before sending results back +res = tvm.compute(output_shape, + lambda *i: res_min(*i).astype(env.inp_dtype), + name="res") + +###################################################################### +# Scheduling the Computation +# -------------------------- +# We'll look at a set of schedule transformations necessary to map the +# matrix multiplications onto VTA in an efficient fashion. +# Those include: +# +# - Computation blocking +# - Computation lowering to VTA hardware intrinsics + + +# Create TVM schedule +s = tvm.create_schedule(res.op) +# Let's look at the default TVM schedule +print(tvm.lower(s, [data, weight, res], simple_mode=True)) + +###################################################################### +# Tiling the Computation +# ~~~~~~~~~~~~~~~~~~~~~~ +# The matrix multiplication is by default too large for activations or weights +# to fit on VTA's on-chip buffers all at once. +# We block the (1, 1024) by (1024, 1024) matrix multiplication into +# smaller (1, 256) by (256, 256) matrix multiplications so the intermediate +# tensors can fit on the accelerator's on-chip SRAM. +# This approach is similar to blocking techniques applied to CPUs and GPUs in +# order to increase cache hit rate. +# +# We perform blocking along each axes (the batch axis being untouched since +# we are performing singe-batch inference). +# We also leave the inner-most tensorization axes as-is in order to allow +# TVM to pattern-match tensorization. +# We show the outcome of blocking on the computation schedule in the diagram +# below: +# +# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/blocking.png +# :align: center +# :height: 367px +# :width: 387px +# +# .. note:: +# +# The code after loop splitting and reordering is equivalent to the following +# pseudo-code. We ignore the batch axis since we are only performing single-batch +# inference in this example: +# +# .. code-block:: c +# +# for (int oc_out = 0; oc_out < 4; ++oc_out) { +# // Initialization loop +# for (int oc_inn = 0; oc_inn < 16; ++oc_inn) { +# for (int oc_tns = 0; oc_tns < 16; ++oc_tns) { +# int j = (oc_out * 16 + oc_inn) * 16 + oc_tns; +# C[0][j] = 0; +# } +# } +# for (int ic_out = 0; ic_out < 4; ++ic_out) { +# // Block loop +# for (int oc_inn = 0; oc_inn < 16; ++oc_inn) { +# for (int ic_inn = 0; ic_inn < 16; ++ic_inn) { +# // Tensorization loop +# for (int oc_tns = 0; oc_tns < 16; ++oc_tns) { +# for (int ic_tns = 0; ic_tns < 16; ++ic_tns) { +# int i = (ic_out * 16 + ic_inn) * 16 + ic_tns; +# int j = (oc_out * 16 + oc_inn) * 16 + oc_tns; +# C[0][i] = C[0][i] + A[0][i] * B[j][i]; +# } +# } +# } +# } +# } +# } +# } + +# Let's define tiling sizes (expressed in multiples of VTA tensor shape size) +b_block = 1 // env.BATCH +i_block = 256 // env.BLOCK_IN +o_block = 256 // env.BLOCK_OUT + +# Tile the output tensor along the batch and output channel dimensions +# (since by default we are doing single batch inference, the split along +# the batch dimension has no effect) +b, oc, b_tns, oc_tns = s[res].op.axis +b_out, b_inn = s[res].split(b, b_block) +oc_out, oc_inn = s[res].split(oc, o_block) +s[res].reorder(b_out, oc_out, b_inn, oc_inn) + +# Move intermediate computation into each output compute tile +s[res_gemm].compute_at(s[res], oc_out) +s[res_shr].compute_at(s[res], oc_out) +s[res_max].compute_at(s[res], oc_out) +s[res_min].compute_at(s[res], oc_out) + +# Apply additional loop split along input channel axis +b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis +ic_out, ic_inn = s[res_gemm].split(ic, i_block) + +# Reorder axes. We move the ic_out axis all the way out of the GEMM +# loop to block along the reduction axis +s[res_gemm].reorder(ic_out, b_inn, oc_inn, ic_inn, b_tns, oc_tns, ic_tns) + +# Let's look at the current TVM schedule after blocking +print(tvm.lower(s, [data, weight, res], simple_mode=True)) + +###################################################################### +# Lowering Copies to DMA Transfers +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Next we set the buffer scopes to the corresponding on-chip VTA SRAM buffers. +# We move the load loops into the matrix multiply computation loop to stage +# memory loads such that they fit in the on-chip SRAM buffers. +# Finally we annotate the load/store loop outer axes with the DMA copy pragma +# to perform bulk memory transfers on VTA. + +# Set scope of SRAM buffers +s[data_buf].set_scope(env.inp_scope) +s[weight_buf].set_scope(env.wgt_scope) +s[res_gemm].set_scope(env.acc_scope) +s[res_shr].set_scope(env.acc_scope) +s[res_min].set_scope(env.acc_scope) +s[res_max].set_scope(env.acc_scope) + +# Block data and weight cache reads +s[data_buf].compute_at(s[res_gemm], ic_out) +s[weight_buf].compute_at(s[res_gemm], ic_out) + +# Use DMA copy pragma on DRAM->SRAM operations +s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy) +s[weight_buf].pragma(s[weight_buf].op.axis[0], env.dma_copy) + +# Use DMA copy pragma on SRAM->DRAM operation +s[res].pragma(s[res].op.axis[2], env.dma_copy) + +###################################################################### +# Lowering Computation to VTA Compute Intrinsics +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# The last phase is to lower the computation loops down to VTA hardware +# intrinsics by mapping the matrix multiplication to tensor intrinsics, +# and mapping the shift, and clipping computation to the vector ALU. + +# Apply tensorization over the batch tensor tile axis +s[res_gemm].tensorize(b_tns, env.gemm) + +# Add an ALU pragma over the shift and clipping operations +s[res_shr].pragma(s[res_shr].op.axis[0], env.alu) +s[res_min].pragma(s[res_min].op.axis[0], env.alu) +s[res_max].pragma(s[res_max].op.axis[0], env.alu) + +# Let's look at the final lowered TVM schedule after lowering memory +# loads/stores down to DMA copy intrinsics, and the computation down to +# VTA compute intrinsics. +print(vta.lower(s, [data, weight, res], simple_mode=True)) + +###################################################################### +# TVM Compilation and Verification +# -------------------------------- +# After specifying the schedule, we can compile it into a TVM function. +# We save the module so we can send it over RPC. +# We run the function and verify it against a numpy implementation to +# ensure correctness. + +# Compile the TVM module +my_gemm = vta.build(s, [data, weight, res], "ext_dev", env.target_host, name="my_gemm") +temp = util.tempdir() +my_gemm.save(temp.relpath("gemm.o")) +remote.upload(temp.relpath("gemm.o")) +f = remote.load_module("gemm.o") + +# Get the remote device context +ctx = remote.ext_dev(0) + +# Initialize the A and B arrays randomly in the int range of (-128, 128] +data = np.random.randint( + -128, 128, size=(batch_size, in_channels)).astype(data.dtype) +weight = np.random.randint( + -128, 128, size=(out_channels, in_channels)).astype(weight.dtype) + +# Apply packing to the A and B arrays from a 2D to a 4D packed layout +data_packed = data.reshape(batch_size // env.BATCH, + env.BATCH, + in_channels // env.BLOCK_IN, + env.BLOCK_IN).transpose((0, 2, 1, 3)) +weight_packed = weight.reshape(out_channels // env.BLOCK_OUT, + env.BLOCK_OUT, + in_channels // env.BLOCK_IN, + env.BLOCK_IN).transpose((0, 2, 1, 3)) + +# Format the input/output arrays with tvm.nd.array to the DLPack standard +data_nd = tvm.nd.array(data_packed, ctx) +weight_nd = tvm.nd.array(weight_packed, ctx) +res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx) + +# Invoke the module to perform the computation +f(data_nd, weight_nd, res_nd) + +# Verify against numpy implementation +res_ref = np.dot(data.astype(env.acc_dtype), + weight.T.astype(env.acc_dtype)) +res_ref = res_ref >> env.INP_WIDTH +res_ref = np.clip(res_ref, 0, inp_max) +res_ref = res_ref.astype(res.dtype) +res_ref = res_ref.reshape(batch_size // env.BATCH, + env.BATCH, + out_channels // env.BLOCK_OUT, + env.BLOCK_OUT).transpose((0, 2, 1, 3)) +np.testing.assert_equal(res_ref, res_nd.asnumpy()) +print("Successful blocked matrix multiply test!") + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates how TVM scheduling primitives can achieve +# computation blocking for a matrix multiplication example. +# This allows us to map arbitrarily large computation onto limited +# hardware accelerator resources. +# + +