Skip to content

Commit

Permalink
merge setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bnellnm committed Mar 18, 2024
1 parent 5393d4c commit af254ce
Showing 1 changed file with 18 additions and 27 deletions.
45 changes: 18 additions & 27 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,31 +184,10 @@ def _is_neuron() -> bool:
return torch_neuronx_installed


def _is_cuda() -> bool:
return (torch.version.cuda is not None) and not _is_neuron()


def _install_punica() -> bool:
return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))


def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)


def find_version(filepath: str) -> str:
"""Extract version information from the given filepath.
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
with open(filepath) as fp:
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
fp.read(), re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")


def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
Expand Down Expand Up @@ -263,11 +242,28 @@ def get_nvcc_cuda_version() -> Version:
return nvcc_cuda_version


def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)


def find_version(filepath: str) -> str:
"""Extract version information from the given filepath.
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
with open(filepath) as fp:
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
fp.read(), re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")


def get_vllm_version() -> str:
version = find_version(get_path("vllm", "__init__.py"))

if _is_cuda():
cuda_version = str(nvcc_cuda_version)
cuda_version = str(get_nvcc_cuda_version())
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
version += f"+cu{cuda_version_str}"
Expand All @@ -283,11 +279,6 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_cuda():
cuda_version = str(get_nvcc_cuda_version())
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
version += f"+cu{cuda_version_str}"
else:
raise RuntimeError("Unknown runtime environment")

Expand Down

0 comments on commit af254ce

Please sign in to comment.