From 006718255cd35f18eae6cc90b29695e205c5a316 Mon Sep 17 00:00:00 2001 From: raresgaia123 <137071040+raresgaia123@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:54:15 +0300 Subject: [PATCH] refactor: run post_setup script on init (#87) made changes to that post_setup script is executed in __init__.py of multiworld error handlers were while executing the main method. with the first execution, the patch is applied and the user is prompted to re-run the script to have the patched version of torch Co-authored-by: Rares Gaia --- multiworld/__init__.py | 4 +++ multiworld/__main__.py | 20 --------------- multiworld/init.txt | 1 + multiworld/post_setup.py | 54 ++++++++++++++++++++++------------------ pyproject.toml | 5 +--- 5 files changed, 36 insertions(+), 48 deletions(-) delete mode 100644 multiworld/__main__.py create mode 100644 multiworld/init.txt diff --git a/multiworld/__init__.py b/multiworld/__init__.py index 3440c3e..a3aa150 100644 --- a/multiworld/__init__.py +++ b/multiworld/__init__.py @@ -22,8 +22,12 @@ from multiworld.version import VERSION as __version__ # noqa: F401 +from . import post_setup + logging.basicConfig( level=getattr(logging, os.getenv("M8D_LOG_LEVEL", "WARNING")), format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s", handlers=[logging.StreamHandler(sys.stdout)], ) + +post_setup.configure_once() diff --git a/multiworld/__main__.py b/multiworld/__main__.py deleted file mode 100644 index ce0ad8f..0000000 --- a/multiworld/__main__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2024 Cisco Systems, Inc. and its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - -"""__main__.py.""" -from multiworld.setup import main - -main() diff --git a/multiworld/init.txt b/multiworld/init.txt new file mode 100644 index 0000000..02e4a84 --- /dev/null +++ b/multiworld/init.txt @@ -0,0 +1 @@ +false \ No newline at end of file diff --git a/multiworld/post_setup.py b/multiworld/post_setup.py index 08adf47..b3ccdd3 100644 --- a/multiworld/post_setup.py +++ b/multiworld/post_setup.py @@ -14,36 +14,39 @@ # # SPDX-License-Identifier: Apache-2.0 -import argparse import os import pathlib import shutil import site +import sys -import torch +def configure_once(): + package_name = __name__.split(".")[0] + path_to_sitepackages = site.getsitepackages()[0] -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "patchfile", nargs="?", default=None, help="Path to the patch file" - ) - args = parser.parse_args() + init_file_path = os.path.join(path_to_sitepackages, package_name, "init.txt") - path_to_sitepackages = site.getsitepackages()[0] + with open(init_file_path, "r") as file: + patch_applied = file.read() + + if patch_applied == "true": + return + + print(f"Configuring {package_name} for the first time. This is one time task.") - if args.patchfile: - patchfile = args.patchfile - else: - torch_version = torch.__version__.split("+")[ - 0 - ] # torch version is in "2.2.1+cu121" format - patchfile = os.path.join( - path_to_sitepackages, - "multiworld", - "patch", - "pytorch-v" + torch_version + ".patch", - ) + import torch + + torch_version = torch.__version__.split("+")[ + 0 + ] # torch version is in "2.2.1+cu121" format + + patchfile = os.path.join( + path_to_sitepackages, + package_name, + "patch", + "pytorch-v" + torch_version + ".patch", + ) patch_basename = os.path.basename(patchfile) @@ -52,10 +55,13 @@ def main(): os.chdir(path_to_sitepackages) - os.system(f"patch -p1 < {patch_basename}") + os.system(f"patch -p1 < {patch_basename} > /dev/null") p = pathlib.Path(patch_basename) p.unlink() + with open(init_file_path, "w") as file: + file.write("true") -if __name__ == "__main__": - main() + sys.exit( + f"This one-time configuration for {package_name} is completed.\nYou can run your script without any interruption from now on." + ) diff --git a/pyproject.toml b/pyproject.toml index 8315e66..dbd35ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,9 +13,6 @@ dependencies = [ "torch == 2.4.0", ] -[project.scripts] -m8d-post-setup = "multiworld.post_setup:main" - [Project.optional-dependencies] dev = [ "black", @@ -30,7 +27,7 @@ dev = [ packages=["multiworld"] [tool.setuptools.package-data] -"multiworld" = ["patch/*.patch"] +"multiworld" = ["patch/*.patch", "init.txt"] [tool.setuptools.dynamic] version = {attr = "multiworld.__version__"}