Skip to content

Commit

Permalink
change copy to include
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed May 24, 2022
1 parent 724c09a commit b31a618
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 36 deletions.
48 changes: 18 additions & 30 deletions apps/microtvm/zephyr/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,11 @@ def get_zephyr_base(options: dict):
return zephyr_base


def get_cmsis_path(options: dict):
def get_cmsis_path(options: dict) -> pathlib.Path:
"""Returns CMSIS dependency path"""
cmsis_path = options.get("cmsis_path", CMSIS_PATH)
assert cmsis_path, "'cmsis_path' option not passed and not found by default!"
return cmsis_path
return pathlib.Path(cmsis_path)


class Handler(server.ProjectAPIHandler):
Expand Down Expand Up @@ -432,28 +432,6 @@ def _get_platform_version(self, zephyr_base: str) -> float:

return float(f"{version_major}.{version_minor}")

def _load_cmsis(self, lib_path: Union[str, pathlib.Path], cmsis_path: Union[str, pathlib.Path]):
"""Copy CMSIS header files to generated project."""

cmsis_path = pathlib.Path(cmsis_path)

lib_path = pathlib.Path(lib_path)
if not lib_path.exists():
lib_path.mkdir()

include_directories = ["CMSIS/DSP/Include", "CMSIS/DSP/Include/dsp", "CMSIS/NN/Include"]
for include_path in include_directories:
include_path = pathlib.Path(include_path)
src = cmsis_path / include_path
dest = lib_path
if include_path.name != "Include":
dest = lib_path / include_path.name
dest.mkdir()

for item in src.iterdir():
if not item.is_dir():
shutil.copy(item, dest / item.name)

def _cmsis_required(self, project_path: Union[str, pathlib.Path]) -> bool:
"""Check if CMSIS dependency is required."""
project_path = pathlib.Path(project_path)
Expand Down Expand Up @@ -496,10 +474,6 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
os.makedirs(extract_path)
tf.extractall(path=extract_path)

# Add CMSIS libraries if required.
if self._cmsis_required(extract_path):
self._load_cmsis(pathlib.Path(project_dir) / "include", get_cmsis_path(options))

if self._is_qemu(options):
shutil.copytree(API_SERVER_DIR / "qemu-hack", project_dir / "qemu-hack")

Expand All @@ -515,8 +489,8 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
shutil.copy2(src_path, dst_path)

# Populate Makefile.
with open(API_SERVER_DIR / "CMakeLists.txt.template", "r") as cmake_template_f:
with open(project_dir / "CMakeLists.txt", "w") as cmake_f:
with open(project_dir / "CMakeLists.txt", "w") as cmake_f:
with open(API_SERVER_DIR / "CMakeLists.txt.template", "r") as cmake_template_f:
for line in cmake_template_f:
if self.API_SERVER_CRT_LIBS_TOKEN in line:
crt_libs = self.CRT_LIBS_BY_PROJECT_TYPE[options["project_type"]]
Expand All @@ -529,6 +503,20 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
for item in flags:
cmake_f.write(f"target_compile_definitions(app PUBLIC {item})\n")

# Include CMSIS libraries if required.
if self._cmsis_required(extract_path):
cmsis_path = get_cmsis_path(options)
cmake_f.write("\n")
cmake_f.write(
f'target_include_directories(tvm_model PRIVATE {str(cmsis_path / "CMSIS" / "DSP" / "Include")})\n'
)
cmake_f.write(
f'target_include_directories(tvm_model PRIVATE {str(cmsis_path / "CMSIS" / "DSP" / "Include" / "dsp")})\n'
)
cmake_f.write(
f'target_include_directories(tvm_model PRIVATE {str(cmsis_path / "CMSIS" / "NN" / "Include")})\n'
)

self._create_prj_conf(project_dir, options)

# Populate crt-config.h
Expand Down
12 changes: 6 additions & 6 deletions tests/micro/zephyr/test_zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,13 @@ def test_schedule_build_with_cmsis_dependency(temp_dir, board, west_cmd, tvm_deb
)
project.build()

generated_project_include_files = []
for path in (project_dir / "include").iterdir():
if path.is_file():
generated_project_include_files.append(path.name)
with open(project_dir / "CMakeLists.txt", "r") as cmake_f:
cmake_content = cmake_f.read()

assert "arm_math.h" in generated_project_include_files
assert "arm_nnsupportfunctions.h" in generated_project_include_files
assert "CMSIS/DSP/Include" in cmake_content
assert "CMSIS/DSP/Include/dsp" in cmake_content
assert "CMSIS/DSP/Include" in cmake_content
assert "CMSIS/NN/Include" in cmake_content


if __name__ == "__main__":
Expand Down

0 comments on commit b31a618

Please sign in to comment.