Skip to content

Commit

Permalink
refactor: run post_setup script on init (#87)
Browse files Browse the repository at this point in the history
made changes to that post_setup script is executed in __init__.py of multiworld
error handlers were while executing the main method.
with the first execution, the patch is applied and the user is prompted to re-run the script to have the patched version of torch

Co-authored-by: Rares Gaia <[email protected]>
  • Loading branch information
raresgaia123 and Rares Gaia authored Aug 29, 2024
1 parent 3067cd5 commit 0067182
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 48 deletions.
4 changes: 4 additions & 0 deletions multiworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@

from multiworld.version import VERSION as __version__ # noqa: F401

from . import post_setup

logging.basicConfig(
level=getattr(logging, os.getenv("M8D_LOG_LEVEL", "WARNING")),
format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)

post_setup.configure_once()
20 changes: 0 additions & 20 deletions multiworld/__main__.py

This file was deleted.

1 change: 1 addition & 0 deletions multiworld/init.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
false
54 changes: 30 additions & 24 deletions multiworld/post_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,39 @@
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import pathlib
import shutil
import site
import sys

import torch

def configure_once():
package_name = __name__.split(".")[0]
path_to_sitepackages = site.getsitepackages()[0]

def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"patchfile", nargs="?", default=None, help="Path to the patch file"
)
args = parser.parse_args()
init_file_path = os.path.join(path_to_sitepackages, package_name, "init.txt")

path_to_sitepackages = site.getsitepackages()[0]
with open(init_file_path, "r") as file:
patch_applied = file.read()

if patch_applied == "true":
return

print(f"Configuring {package_name} for the first time. This is one time task.")

if args.patchfile:
patchfile = args.patchfile
else:
torch_version = torch.__version__.split("+")[
0
] # torch version is in "2.2.1+cu121" format
patchfile = os.path.join(
path_to_sitepackages,
"multiworld",
"patch",
"pytorch-v" + torch_version + ".patch",
)
import torch

torch_version = torch.__version__.split("+")[
0
] # torch version is in "2.2.1+cu121" format

patchfile = os.path.join(
path_to_sitepackages,
package_name,
"patch",
"pytorch-v" + torch_version + ".patch",
)

patch_basename = os.path.basename(patchfile)

Expand All @@ -52,10 +55,13 @@ def main():

os.chdir(path_to_sitepackages)

os.system(f"patch -p1 < {patch_basename}")
os.system(f"patch -p1 < {patch_basename} > /dev/null")
p = pathlib.Path(patch_basename)
p.unlink()

with open(init_file_path, "w") as file:
file.write("true")

if __name__ == "__main__":
main()
sys.exit(
f"This one-time configuration for {package_name} is completed.\nYou can run your script without any interruption from now on."
)
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ dependencies = [
"torch == 2.4.0",
]

[project.scripts]
m8d-post-setup = "multiworld.post_setup:main"

[Project.optional-dependencies]
dev = [
"black",
Expand All @@ -30,7 +27,7 @@ dev = [
packages=["multiworld"]

[tool.setuptools.package-data]
"multiworld" = ["patch/*.patch"]
"multiworld" = ["patch/*.patch", "init.txt"]

[tool.setuptools.dynamic]
version = {attr = "multiworld.__version__"}

0 comments on commit 0067182

Please sign in to comment.