diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index 5ca5d596a007..8caa48a56104 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -28,7 +28,7 @@ import tvm from tvm import rpc from tvm.contrib import cc -from pynq import Bitstream +from vta import program_bitstream from ..environment import get_env from ..pkg_config import PkgConfig @@ -67,9 +67,9 @@ def ext_dev_callback(): @tvm.register_func("tvm.contrib.vta.init", override=True) def program_fpga(file_name): path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name) - bitstream = Bitstream(path) - bitstream.download() - logging.info("Program FPGA with %s", file_name) + env = get_env() + program_bitstream.bitstream_program(env.TARGET, path) + logging.info("Program FPGA with %s ", file_name) @tvm.register_func("tvm.rpc.server.shutdown", override=True) def server_shutdown(): diff --git a/vta/python/vta/program_bitstream.py b/vta/python/vta/program_bitstream.py new file mode 100644 index 000000000000..5c5a86293885 --- /dev/null +++ b/vta/python/vta/program_bitstream.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""VTA specific bitstream program library.""" +import os +import argparse + +def main(): + """Main funciton""" + parser = argparse.ArgumentParser() + parser.add_argument("target", type=str, default="", + help="target") + parser.add_argument("bitstream", type=str, default="", + help="bitstream path") + args = parser.parse_args() + + if (args.target != 'pynq' and args.target != 'sim'): + raise RuntimeError("Unknown target {}".format(args.target)) + + curr_path = os.path.dirname( + os.path.abspath(os.path.expanduser(__file__))) + path_list = [ + os.path.join(curr_path, "/{}".format(args.bitstream)), + os.path.join('./', "{}".format(args.bitstream)) + ] + ok_path_list = [p for p in path_list if os.path.exists(p)] + if not ok_path_list: + raise RuntimeError("Cannot find bitstream file in %s" % str(path_list)) + + bitstream_program(args.target, args.bitstream) + +def pynq_bitstream_program(bitstream_path): + from pynq import Bitstream + bitstream = Bitstream(bitstream_path) + bitstream.download() + +def bitstream_program(target, bitstream): + if target == 'pynq': + pynq_bitstream_program(bitstream) + elif target != 'sim': + raise RuntimeError("Unknown target {}".format(target)) + +if __name__ == "__main__": + main() diff --git a/vta/tests/hardware/common/test_lib.cc b/vta/tests/hardware/common/test_lib.cc index 291016a4ef3f..e88cede4d055 100644 --- a/vta/tests/hardware/common/test_lib.cc +++ b/vta/tests/hardware/common/test_lib.cc @@ -52,12 +52,6 @@ uint64_t vta( snprintf(str_block_bit_width, sizeof(str_block_bit_width), "%d", VTA_WGT_WIDTH); snprintf(bitstream, sizeof(bitstream), "%s", "vta.bit"); -#if VTA_DEBUG == 1 - printf("INFO - Programming FPGA: %s!\n", bitstream); -#endif - - // Program VTA - VTAProgram(bitstream); // Get VTA handles void* vta_fetch_handle = VTAMapRegister(VTA_FETCH_ADDR, VTA_RANGE); void* vta_load_handle = VTAMapRegister(VTA_LOAD_ADDR, VTA_RANGE); diff --git a/vta/tests/hardware/metal_test/Makefile b/vta/tests/hardware/metal_test/Makefile index 67563f324734..ef1dfc274916 100644 --- a/vta/tests/hardware/metal_test/Makefile +++ b/vta/tests/hardware/metal_test/Makefile @@ -18,7 +18,7 @@ CC ?= g++ CFLAGS = -Wall -O3 -std=c++11 -I/usr/include LDFLAGS = -L/usr/lib -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ -LIBS = -l:libsds_lib.so -l:libdma.so -lstdc++ +LIBS = -l:libcma.so -lstdc++ -pthread INCLUDE_DIR = ../../../include DRIVER_DIR = ../../../src/pynq TESTLIB_DIR = ../common @@ -33,11 +33,15 @@ CFLAGS += `${VTA_CONFIG} --cflags` LDFLAGS += `${VTA_CONFIG} --ldflags` VTA_TARGET := $(shell ${VTA_CONFIG} --target) +# Include bitstream +VTA_PROGRAM = python3 ../../../python/vta/program_bitstream.py +VTA_BIT = "vta.bit" + # Define flags CFLAGS += -I $(INCLUDE_DIR) -DNO_SIM -DVTA_DEBUG=0 # All Target -all: $(EXECUTABLE) +all: vtainstall $(EXECUTABLE) %.o: %.cc $(SOURCES) $(CC) -c -o $@ $< $(CFLAGS) @@ -45,5 +49,7 @@ all: $(EXECUTABLE) $(EXECUTABLE): $(OBJECTS) $(CC) $(LDFLAGS) $(OBJECTS) -o $@ $(LIBS) +vtainstall: + ${VTA_PROGRAM} ${VTA_TARGET} ${VTA_BIT} clean: rm -rf *.o $(EXECUTABLE)