Skip to content

Commit

Permalink
Simplify setup.py
Browse files Browse the repository at this point in the history
* Remove a hack to use `nvcc` provided from env variable.
This was to support `ccache`, but is no longer needed with PyTorch 2.1+
* Write the `version.py` file (which provides `__version__`) when we build
There was some complexity before, because we wanted the version for source-distr
to not be re-calculated when we install.
Now in the CI we set `version.txt` to the actual version and use that at build time.

ghstack-source-id: 19a3d43b48f7b835c97b4fe13f50cb58e7e4d540
Pull Request resolved: https://github.com/fairinternal/xformers/pull/945

__original_commit__ = fairinternal/xformers@56c13f872bf139826739adf1dad7ba307cc666c8
  • Loading branch information
xFormers Bot committed Nov 29, 2023
1 parent 5c4ee77 commit b631a0e
Showing 1 changed file with 36 additions and 61 deletions.
97 changes: 36 additions & 61 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import datetime
import distutils.command.clean
import glob
import importlib.util
import json
import os
import platform
Expand Down Expand Up @@ -49,6 +48,9 @@ def fetch_requirements():


def get_local_version_suffix() -> str:
if not (Path(__file__).parent / ".git").is_dir():
# Most likely installing from a source distribution
return ""
date_suffix = datetime.datetime.now().strftime("%Y%m%d")
git_hash = subprocess.check_output(
["git", "rev-parse", "--short", "HEAD"], cwd=Path(__file__).parent
Expand All @@ -70,14 +72,13 @@ def get_flash_version() -> str:
return "v?"


def write_version_file(version: str):
version_path = os.path.join(this_dir, "xformers", "version.py")
with open(version_path, "w") as f:
f.write("# noqa: C801\n")
f.write(f'__version__ = "{version}"\n')
tag = os.getenv("GIT_TAG")
if tag is not None:
f.write(f'git_tag = "{tag}"\n')
def generate_version_py(version: str) -> str:
content = "# noqa: C801\n"
content += f'__version__ = "{version}"\n'
tag = os.getenv("GIT_TAG")
if tag is not None:
content += f'git_tag = "{tag}"\n'
return content


def symlink_package(name: str, path: Path, is_building_wheel: bool) -> None:
Expand Down Expand Up @@ -339,74 +340,44 @@ def run(self):
distutils.command.clean.clean.run(self)


class BuildExtensionWithMetadata(BuildExtension):
class BuildExtensionWithExtraFiles(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
self.xformers_build_metadata = kwargs.pop("xformers_build_metadata")
self.xformers_build_metadata = kwargs.pop("extra_files")
self.pkg_name = "xformers"
self.metadata_json = "cpp_lib.json"
super().__init__(*args, **kwargs)

@staticmethod
def _join_cuda_home(*paths) -> str:
"""
Hackfix to support custom `nvcc` binary (eg ccache)
TODO: Remove once we use PT 2.1.0 (https://github.com/pytorch/pytorch/pull/96987)
"""
if paths == ("bin", "nvcc") and "PYTORCH_NVCC" in os.environ:
return os.environ["PYTORCH_NVCC"]
if CUDA_HOME is None:
raise EnvironmentError(
"CUDA_HOME environment variable is not set. "
"Please set it to your CUDA install root."
)
return os.path.join(CUDA_HOME, *paths)

def build_extensions(self) -> None:
torch.utils.cpp_extension._join_cuda_home = (
BuildExtensionWithMetadata._join_cuda_home
)
super().build_extensions()
with open(
os.path.join(self.build_lib, self.pkg_name, self.metadata_json), "w+"
) as fp:
json.dump(self.xformers_build_metadata, fp)
for filename, content in self.xformers_build_metadata.items():
with open(
os.path.join(self.build_lib, self.pkg_name, filename), "w+"
) as fp:
fp.write(content)

def copy_extensions_to_source(self):
def copy_extensions_to_source(self) -> None:
"""
Used for `pip install -e .`
Copies everything we built back into the source repo
"""
build_py = self.get_finalized_command("build_py")
package_dir = build_py.get_package_dir(self.pkg_name)
inplace_file = os.path.join(package_dir, self.metadata_json)
regular_file = os.path.join(self.build_lib, self.pkg_name, self.metadata_json)
self.copy_file(regular_file, inplace_file, level=self.verbose)

for filename in self.xformers_build_metadata.keys():
inplace_file = os.path.join(package_dir, filename)
regular_file = os.path.join(self.build_lib, self.pkg_name, filename)
self.copy_file(regular_file, inplace_file, level=self.verbose)
super().copy_extensions_to_source()


if __name__ == "__main__":

try:
# when installing as a source distribution, the version module should exist
# Let's import it manually to not trigger the load of the C++
# library - which does not exist yet, and creates a WARNING
spec = importlib.util.spec_from_file_location(
"xformers_version", os.path.join(this_dir, "xformers", "version.py")
)
if spec is None or spec.loader is None:
raise FileNotFoundError()
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
except FileNotFoundError:
if os.getenv("BUILD_VERSION"): # In CI
version = os.getenv("BUILD_VERSION", "0.0.0")
else:
version_txt = os.path.join(this_dir, "version.txt")
with open(version_txt) as f:
version = f.readline().strip()
version += get_local_version_suffix()
write_version_file(version)
if os.getenv("BUILD_VERSION"): # In CI
version = os.getenv("BUILD_VERSION", "0.0.0")
else:
version_txt = os.path.join(this_dir, "version.txt")
with open(version_txt) as f:
version = f.readline().strip()
version += get_local_version_suffix()

is_building_wheel = "bdist_wheel" in sys.argv
# Embed a fixed version of flash_attn
Expand All @@ -428,8 +399,12 @@ def copy_extensions_to_source(self):
packages=setuptools.find_packages(exclude=("tests*", "benchmarks*")),
ext_modules=extensions,
cmdclass={
"build_ext": BuildExtensionWithMetadata.with_options(
no_python_abi_suffix=True, xformers_build_metadata=extensions_metadata
"build_ext": BuildExtensionWithExtraFiles.with_options(
no_python_abi_suffix=True,
extra_files={
"cpp_lib.json": json.dumps(extensions_metadata),
"version.py": generate_version_py(version),
},
),
"clean": clean,
},
Expand Down

0 comments on commit b631a0e

Please sign in to comment.