From cef69d8f6c8dc5ee051517c759024dea364b3e6d Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Fri, 11 Oct 2024 00:18:16 -0700 Subject: [PATCH] Add wave packaging test This test shows how to package kernels into pip packages for deployment. --- .../kernel/wave/packaging/build_package.py | 49 ++ .../wave/packaging/templates/main.py.j2 | 496 ++++++++++++++++++ .../wave/packaging/templates/setup.py.j2 | 13 + tests/kernel/wave/wave_packaging_test.py | 140 +++++ 4 files changed, 698 insertions(+) create mode 100644 iree/turbine/kernel/wave/packaging/build_package.py create mode 100644 iree/turbine/kernel/wave/packaging/templates/main.py.j2 create mode 100644 iree/turbine/kernel/wave/packaging/templates/setup.py.j2 create mode 100644 tests/kernel/wave/wave_packaging_test.py diff --git a/iree/turbine/kernel/wave/packaging/build_package.py b/iree/turbine/kernel/wave/packaging/build_package.py new file mode 100644 index 00000000..6da47254 --- /dev/null +++ b/iree/turbine/kernel/wave/packaging/build_package.py @@ -0,0 +1,49 @@ +import pathlib +from typing import Any +import jinja2 +import shutil + + +def build_folders(kernel_info: dict[str, Any], output_dir: str): + package_path = pathlib.Path(output_dir) / kernel_info["package_name"] + package_path.mkdir(parents=True, exist_ok=True) + subfolder = package_path / kernel_info["package_name"] + subfolder.mkdir(parents=True, exist_ok=True) + init_file = subfolder / "__init__.py" + with open(init_file, "w") as f: + f.write(f"from .main import {kernel_info['kernel_name']}\n") + return subfolder + + +def copy_artifacts(kernel_info: dict[str, Any], output_dir: str): + shutil.copy(kernel_info["vmfb_path"], output_dir) + + +def render_templates(kernel_info: dict[str, Any], output_dir: str): + parent_dir = pathlib.Path(__file__).resolve().parent + template_loader = jinja2.FileSystemLoader(searchpath=parent_dir / "templates") + template_env = jinja2.Environment(loader=template_loader) + main_template = template_env.get_template("main.py.j2") + updated_template = main_template.render( + kernel_function_name=kernel_info["kernel_name"], + kernel_num_inputs=kernel_info["num_inputs"], + kernel_dispatch_name=kernel_info["dispatch_name"], + vmfb_path=pathlib.Path(kernel_info["vmfb_path"]).name, + ) + with open(output_dir / "main.py", "w") as f: + f.write(updated_template) + setup_template = template_env.get_template("setup.py.j2") + updated_template = setup_template.render( + kernel_package_name=kernel_info["package_name"], + kernel_version=kernel_info["kernel_version"], + ) + with open(output_dir.parents[0] / "setup.py", "w") as f: + f.write(updated_template) + + +def create_pip_package(kernel_info: dict[str, Any], output_dir: str): + """Builds a pip package from the current directory.""" + + subfolder = build_folders(kernel_info, output_dir) + copy_artifacts(kernel_info, subfolder) + render_templates(kernel_info, subfolder) diff --git a/iree/turbine/kernel/wave/packaging/templates/main.py.j2 b/iree/turbine/kernel/wave/packaging/templates/main.py.j2 new file mode 100644 index 00000000..1f6862fd --- /dev/null +++ b/iree/turbine/kernel/wave/packaging/templates/main.py.j2 @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +# Do not modify this file. +# This file is automatically generated from a template in iree/turbine/kernel/wave/packaging/templates/main.py. +# ========================================================================================== + +from functools import lru_cache +import iree.runtime as rt +from typing import Callable, Optional, Union +from threading import local, Lock +import warnings +from iree.runtime import ( + BufferUsage, + HalBufferView, + HalDevice, + HalDriver, + MemoryType, + VmInstance, + VmModule, + create_hal_module, + get_driver, +) +import torch + +_CURRENT_THREAD = local() +_CONFIG_LOCK = Lock() +_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None + + +class MismatchedDeviceSetClearError(AssertionError): + def __init__(self): + super().__init__("Calls to Device.set()/clear() are mismatched or unbalanced.") + + +class UnsupportedTorchDeviceError(Exception): + def __init__(self, torch_device): + super().__init__( + f"Attempt to use turbine with a torch.device that is not supported by this build: {torch_device}" + ) + + +class NoCurrentDeviceError(Exception): + def __init__(self): + super().__init__( + "You accessed a method which requires a current device but none was set on this thread. " + "Either pass an explicit 'device=' or set a current device via " + "`with device:`" + ) + + +def get_vm_instance() -> VmInstance: + global _GLOBAL_VM_INSTANCE + if not _GLOBAL_VM_INSTANCE: + with _CONFIG_LOCK: + if not _GLOBAL_VM_INSTANCE: + _GLOBAL_VM_INSTANCE = VmInstance() + return _GLOBAL_VM_INSTANCE + + +class DeviceState: + """State for an instantiated HAL device. + + Note that the IREE runtime internally manages a global cache of drivers for + standard named-access (not custom-constructed) drivers. + """ + + __slots__ = [ + "device", + "driver", + "instance", + "enumerated_info", + "torch_device", + "dlpack_device_type_code", + ] + + def __init__( + self, + *, + driver: Union[str, HalDriver], + device: Optional[HalDevice] = None, + vm_instance: Optional[VmInstance] = None, + enumerated_info: Optional[dict] = None, + torch_device: Optional[torch.device] = None, + dlpack_device_type_code: int = 0, + ): + self.instance = vm_instance or get_vm_instance() + self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) + self.device = device if device else self.driver.create_default_device() + self.enumerated_info = enumerated_info or {} + self.torch_device = torch_device + self.dlpack_device_type_code = dlpack_device_type_code + + @property + def enumerated_device_id(self) -> int: + try: + return self.enumerated_info["device_id"] + except KeyError as e: + raise RuntimeError("No enumerated device_id for device") from e + + @property + def enumerated_path(self) -> str: + try: + return self.enumerated_info["path"] + except KeyError as e: + raise RuntimeError("No enumerated path for device") from e + + @property + def enumerated_name(self) -> str: + try: + return self.enumerated_info["name"] + except KeyError as e: + raise RuntimeError("No enumerated name for device") from e + + @staticmethod + @lru_cache(maxsize=None) + def from_uri(uri: str) -> "DeviceState": + driver = get_driver(uri) + return DeviceState(driver=driver, device=driver.create_device_by_uri(uri)) + + +class Device: + """Represents a low-level device (HalDriver/HalDevice) and scheduling data. + + This is the type that user's interact with as a 'Device'. Devices can be handled + loose-leaf or bound to a thread with a context manager. + """ + + __slots__ = [ + "_s", + "_main_timeline", + "_main_timepoint", + "_tx_timeline", + "_tx_timepoint", + "_fence_capacity", + "compile_target_flags", + "driver_id", + "export_torch_tensor", + "import_torch_tensor", + "instance_cache_key", + "type_cache_key", + ] + + _s: DeviceState + + # Each device will have a function attached to import a torch.tensor + # *that is already on that device* directly from device memory. + # This is unsafe and relatively unchecked. If criss-crossing devices, + # it is undefined behavior. + import_torch_tensor: Callable[[torch.Tensor], HalBufferView] + + # Devices can also export a torch tensor from a HalBufferView, given + # a meta tensor that describes it. + export_torch_tensor: Callable[[HalBufferView, torch.Tensor], torch.Tensor] + + # Unique name of the IREE runtime driver associated with this device. + driver_id: str + + # Cache key that uniquely identifies this device. + instance_cache_key: str + + # Cache key that uniquely identifies this type of device (currently + # based on its driver). + type_cache_key: str + + # Compiler flags to use to target this device. + # TODO: We should replace this with a target attribute but need an API + # to derive that. + compile_target_flags: tuple[str, ...] + + def __new__( + cls, + uri: Optional[str] = None, + *, + device_state: Optional[DeviceState] = None, + ): + if uri is not None: + # Construction by URI is cached on the thread. + assert not device_state, "device_state= cannot be given with explicit URI" + try: + existing = _CURRENT_THREAD.device_by_uri[uri] + except (AttributeError, KeyError): + ... + else: + return existing + + # New instance. + device_state = DeviceState.from_uri(uri) + new_inst = super().__new__(cls) + new_inst._s = device_state + try: + _CURRENT_THREAD.device_by_uri[uri] = new_inst + except AttributeError: + _CURRENT_THREAD.device_by_uri = {uri: new_inst} + new_inst._initialize() + return new_inst + else: + # Explicit construction with a device_state is assumed that you know what you + # are doing and an uncached instance will be returned. This will be unsychronized + # relative to any cached instance. + assert device_state, "device_state= must be given if URI ommitted" + new_inst = super().__new__(cls) + new_inst._s = device_state + new_inst._initialize() + return new_inst + + def _initialize(self): + d = self._s.device + self._main_timeline = d.create_semaphore(0) + self._main_timepoint = 0 + self._tx_timeline = d.create_semaphore(0) + self._tx_timepoint = 0 + # Maximum number of semaphores the device uses. Can be increased if doing out of the + # ordinary scheduling. + self._fence_capacity = 2 + + # Perform driver specific augmentations. + # TODO: Add a HalDriver.id property to get the driver name instead of parsing + # the device repr. + driver_id = repr(d) + colon_pos = driver_id.find(":") + if colon_pos >= 0: + driver_id = driver_id[0:colon_pos] + self.driver_id = driver_id + try: + import_fn = TORCH_TENSOR_IMPORTERS[driver_id] + export_fn = TORCH_TENSOR_EXPORTERS[driver_id] + self.import_torch_tensor = lambda t: import_fn(self, t) + self.export_torch_tensor = lambda bv, t: export_fn(self, bv, t) + except KeyError as e: + raise AssertionError( + f"Unsupported TORCH_TENSOR_IMPORTERS for iree driver '{driver_id}'" + ) from e + + # Cache keys. + # TODO: The type cache key should actually be based on the driver id + # and device characteristics hash. + self.instance_cache_key = repr(d) + self._recompute_target_keys() + + def _recompute_target_keys(self): + self.type_cache_key = f"{self.driver_id}:{';'.join(self.compile_target_flags)}" + + @property + def hal_device(self) -> HalDevice: + return self._s.device + + @property + def vm_instance(self) -> VmInstance: + return self._s.instance + + def create_hal_module(self) -> VmModule: + s = self._s + return create_hal_module(s.instance, s.device) + + @staticmethod + def current() -> "Device": + try: + return _CURRENT_THREAD.stack[-1] + except (AttributeError, IndexError): + raise NoCurrentDeviceError() + + def set(self) -> "Device": + """Sets this device as the current device without a context manager.""" + try: + _CURRENT_THREAD.stack.append(self) + except AttributeError: + _CURRENT_THREAD.stack = [self] + return self + + def clear(self): + """Clears the current device without a context manager.""" + try: + c = _CURRENT_THREAD.stack[-1] + if _CURRENT_THREAD.stack[-1] is self: + _CURRENT_THREAD.stack.pop() + return + except (AttributeError, IndexError): + ... + raise MismatchedDeviceSetClearError() + + def dump_device_info(self) -> str: + return self._s.driver.dump_device_info(self._s.enumerated_device_id) + + def __repr__(self): + return f"" + + def __enter__(self): + try: + _CURRENT_THREAD.stack.append(self) + except AttributeError: + _CURRENT_THREAD.stack = [self] + + def __exit__(self, type, value, traceback): + _CURRENT_THREAD.stack.pop() + + +################################################################################ +# CUDA and HIP import/export +################################################################################ + + +def _device_import_torch_tensor_cuda_hip( + device: Device, t: torch.Tensor +) -> HalBufferView: + # We currently only support contiguous, so ensure that. + if not t.is_contiguous(): + t = t.contiguous() + # TODO: The 'None' here tells the producer to synchronize on the default + # stream. For async, we should advance our timeline and signal when an + # event is raised on Torch's stream at the current position. + capsule = t.__dlpack__(None) + bv = device.hal_device.from_dlpack_capsule(capsule) + return bv + + +def _device_export_torch_tensor_cuda_hip( + device: Device, bv: HalBufferView, like: torch.Tensor +) -> torch.Tensor: + state = device._s + device_type_code = state.dlpack_device_type_code + assert device_type_code > 0 + torch_device = state.torch_device + assert torch_device is not None + device_index = torch_device.index + t = torch.from_dlpack( + device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index) + ) + if t.dtype != like.dtype: + t = t.view(like.dtype) + # TODO: For async, we should enqueue an event on Torch's stream which will + # signal when this tensor is produced (i.e. at the current point in our + # timeline). + return t + + +# Mapping of torch tensor importers keyed by driver name. +TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { + "cuda": _device_import_torch_tensor_cuda_hip, + "hip": _device_import_torch_tensor_cuda_hip, +} + +TORCH_TENSOR_EXPORTERS: dict[ + str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] +] = { + "cuda": _device_export_torch_tensor_cuda_hip, + "hip": _device_export_torch_tensor_cuda_hip, +} + +############################################################################### +# torch.device to Device mapping +############################################################################### + + +def lookup_device_from_torch( + torch_device: torch.device, *, create: bool = True +) -> Optional[Device]: + """Gets a shared Device corresponding to the given torch.device. + + This will return None if the device is wholly unsupported or if + create=False. Otherwise, faults in setting up the device are + reported as an appropriate exception. + """ + try: + mapping = _CURRENT_THREAD.device_by_torch_device + except AttributeError: + _CURRENT_THREAD.device_by_torch_device = mapping = {} + device = mapping.get(torch_device) + if device is not None or not create: + return device + device = _create_device_from_torch(torch_device) + if device is not None: + mapping[torch_device] = device + return device + + +def get_device_from_torch(torch_device: torch.device) -> Device: + """Gets a shared Device corresponding to the given torch.device. + + Raises an exception if the device cannot be created. + """ + device = lookup_device_from_torch(torch_device) + if device is None: + raise UnsupportedTorchDeviceError(torch_device) + return device + + +def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: + torch_type = torch_device.type + if torch_type == "cuda": + # Fork based on HIP or real CUDA. + props = torch.cuda.get_device_properties(torch_device) + if not hasattr(props, "gcnArchName"): + # Real CUDA. + return _create_cuda_device(torch_device, props) + else: + # HIP as CUDA. + return _create_hip_device(torch_device, props) + + return None + + +def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: + # Note that the dlpack device type code for real CUDA ROCM is 2. + device = _create_cuda_like_device(torch_device, props, "hip", 2) + if device: + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}", + ) + device._recompute_target_keys() + return device + + +def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: + # Note that the dlpack device type code for ROCM is 10. + device = _create_cuda_like_device(torch_device, props, "hip", 10) + # The gcnArchName comes back like gfx90a:sramecc+:xnack- for a fully + # specified target. However the IREE target-chip flag only expects the + # prefix. See: https://github.com/iree-org/iree/issues/17402 + # This should be changed to tunnel through target information unmolested. + gcn_arch_name: str = props.gcnArchName + colon_pos = gcn_arch_name.find(":") + if colon_pos >= 0: + gcn_arch_name = gcn_arch_name[0:colon_pos] + if device: + gcn_arch_name = gcn_arch_name + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-rocm-target-chip={gcn_arch_name}", + ) + device._recompute_target_keys() + return device + + +def _create_cuda_like_device( + torch_device: torch.device, props, driver_name: str, dlpack_device_type_code: int +) -> Optional[Device]: + if torch.cuda.device_count() > 1: + warnings.warn( + f"Multiple {driver_name} devices detected: Turbine does not yet " + f"guarantee stable device mapping" + ) + + requested_index = torch_device.index + driver = get_driver(driver_name) + available_infos = driver.query_available_devices() + if requested_index >= len(available_infos): + return None + device_info = available_infos[requested_index] + hal_device = driver.create_device(device_info) + device_state = DeviceState( + driver=driver, + device=hal_device, + vm_instance=get_vm_instance(), + enumerated_info=device_info, + torch_device=torch_device, + dlpack_device_type_code=dlpack_device_type_code, + ) + device = Device(device_state=device_state) + return device + + +@lru_cache(maxsize=None) +def module(device: Device): + return rt.VmModule.mmap(device._s.instance, "{{vmfb_path}}") + +@lru_cache(maxsize=None) +def context(device: Device): + return rt.VmContext( + device._s.instance, + (create_hal_module(device._s.instance, device._s.device), module(device)), + ) + + +@lru_cache(maxsize=None) +def func(device, name): + return module(device).lookup_function(name) + + +def {{kernel_function_name}}(*args): + arg_list, ret_list = [], [] + device = None + num_inputs = {{kernel_num_inputs}} + num_outputs = len(args) - num_inputs + for i, arg in enumerate(args): + if device is None: + device = get_device_from_torch(args[0].device) + assert device is not None, "Device not found" + if i < num_inputs: + arg_list.append(TORCH_TENSOR_IMPORTERS[arg.dtype](device, arg)) + else: + ret_list.append(TORCH_TENSOR_IMPORTERS[arg.dtype](device, arg)) + context(device).vm_context.invoke( + func(device, {{kernel_dispatch_name}}), arg_list, ret_list + ) + return_values = [] + for ret in ret_list: + return_values.append(TORCH_TENSOR_EXPORTERS[ret.dtype](device, ret)) + return return_values[0] if num_outputs == 1 else return_values diff --git a/iree/turbine/kernel/wave/packaging/templates/setup.py.j2 b/iree/turbine/kernel/wave/packaging/templates/setup.py.j2 new file mode 100644 index 00000000..77b9f37a --- /dev/null +++ b/iree/turbine/kernel/wave/packaging/templates/setup.py.j2 @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name="{{kernel_package_name}}", + version="{{kernel_version}}", + packages=find_packages(), + include_package_data=True, + package_data={"": ["*.vmfb"]}, + install_requires=[ + "iree-runtime==20240918.1020", + "iree-turbine@git+https://github.com/iree-org/iree-turbine.git@main", + ], +) diff --git a/tests/kernel/wave/wave_packaging_test.py b/tests/kernel/wave/wave_packaging_test.py new file mode 100644 index 00000000..4c4f71f2 --- /dev/null +++ b/tests/kernel/wave/wave_packaging_test.py @@ -0,0 +1,140 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import pytest +import torch +import unittest +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.packaging.build_package import create_pip_package +import os +import json +from torch.testing import assert_close + + +def packageTest(): + shape = (2048, 1280, 1280) + enable_scheduling = True + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). + # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + # a_reg: tkw.Register[M, K, tkl.f16] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[N, K, tkl.f16] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[M, N, tkl.f32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + N: shape[1], + K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + config = { + "backend": "rocm", + "device": "hip", + "target": "gfx942", + "dump_vmfb_file": "artifacts/kernel.vmfb", + } + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=False, + run_config=config, + schedule=enable_scheduling, + ): + a = torch.randn(shape[0], shape[2], dtype=torch.float16) + b = torch.randn(shape[1], shape[2], dtype=torch.float16) + c = torch.zeros(shape[0], shape[1], dtype=torch.float32) + gemm(a, b, c) + + # Create the pip package + kernel_info = { + "package_name": "libtkw", + "kernel_name": "gemm_f32_2048x1280x1280_f16", + "num_inputs": 2, + "dispatch_name": "isolated_benchmark", + "vmfb_path": "artifacts/kernel.vmfb", + "kernel_version": "0.0.1", + } + create_pip_package(output_dir="pip_package/", kernel_info=kernel_info) + # Run python setup.py bdist_wheel in pip_package/ to build the wheel. + # Once the wheel is built, it can be installed using + # pip install .whl --find-links https://iree.dev/pip-release-links.html + # The kernel can then be invoked from Python as follows: + # import libtkw + # import torch + # a = torch.randn(2048, 1280, dtype=torch.float16, device="cuda") + # b = torch.randn(1280, 1280, dtype=torch.float16, device="cuda") + # c = torch.empty(2048, 1280, dtype=torch.float32, device="cuda") + # libtkw.gemm_f32_2048x1280x1280_f16(a, b, c) + assert os.path.exists("pip_package/libtkw-0.0.1-py3-none-any.whl") + + +packageTest()