Skip to content

Commit

Permalink
Add install.py for test_bench (#2546)
Browse files Browse the repository at this point in the history
Summary:
Add install.py for test_bench introduced in #2052

### Usage

```bash
$ python install.py --userbenchmark test_bench --models BERT_pytorch hf_GPT2 --skip hf_GPT2
checking packages numpy, torch are installed, generating constaints...OK
Installing userbenchmark test_bench with extra args: ['--models']
Installing BERT_pytorch...
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Obtaining file:///Users/shenke/workspace/benchmark/torchbenchmark/models/BERT_pytorch
  Preparing metadata (setup.py) ... done
Requirement already satisfied: tqdm in /opt/anaconda3/envs/nightlypth/lib/python3.11/site-packages (from bert_pytorch==0.0.1a4) (4.66.4)
Requirement already satisfied: numpy in /opt/anaconda3/envs/nightlypth/lib/python3.11/site-packages (from bert_pytorch==0.0.1a4) (1.24.4)
Installing collected packages: bert_pytorch
  Attempting uninstall: bert_pytorch
    Found existing installation: bert_pytorch 0.0.1a4
    Uninstalling bert_pytorch-0.0.1a4:
      Successfully uninstalled bert_pytorch-0.0.1a4
  Running setup.py develop for bert_pytorch
Successfully installed bert_pytorch-0.0.1a4
```

Pull Request resolved: #2546

Reviewed By: xuzhao9

Differential Revision: D66457458

Pulled By: FindHao

fbshipit-source-id: 73b5f88dc50bd27eceb91c456279d7a687656c7c
  • Loading branch information
shink authored and facebook-github-bot committed Nov 25, 2024
1 parent 341ad14 commit 9ff23ab
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 6 deletions.
6 changes: 6 additions & 0 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
print(
f"Installing userbenchmark {args.userbenchmark} with extra args: {extra_args}"
)
if args.models:
cmd.extend(["--models"] + args.models)
if args.skip:
cmd.extend(["--skip"] + args.skip)
if args.canary:
cmd.extend(["--canary"])
cmd.extend(extra_args)
if userbenchmark_dir.joinpath("install.py").is_file():
# add the current run env to PYTHONPATH to load framework install utils
Expand Down
21 changes: 15 additions & 6 deletions torchbenchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,11 @@ def _is_canary_model(model_name: str) -> bool:
return False


def setup(
def _filter_model_paths(
models: Optional[List[str]] = None,
skip_models: Optional[List[str]] = None,
verbose: bool = True,
continue_on_fail: bool = False,
test_mode: bool = False,
allow_canary: bool = False,
) -> bool:
failures = {}
) -> List[str]:
models = list(map(lambda p: p.lower(), models))
model_paths = filter(
lambda p: True if not models else os.path.basename(p).lower() in models,
Expand All @@ -195,6 +191,19 @@ def setup(
model_paths.extend(canary_model_paths)
skip_models = [] if not skip_models else skip_models
model_paths = [x for x in model_paths if os.path.basename(x) not in skip_models]
return model_paths


def setup(
models: Optional[List[str]] = None,
skip_models: Optional[List[str]] = None,
verbose: bool = True,
continue_on_fail: bool = False,
test_mode: bool = False,
allow_canary: bool = False,
) -> bool:
failures = {}
model_paths = _filter_model_paths(models, skip_models, allow_canary)
for model_path in model_paths:
print(f"running setup for {model_path}...", end="", flush=True)
if test_mode:
Expand Down
46 changes: 46 additions & 0 deletions userbenchmark/test_bench/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import argparse
import os
import subprocess
import sys

# CLI for selecting which models to install. Unrecognized flags are kept in
# `extra_args` via parse_known_args (the top-level install.py forwards
# benchmark-specific options to this script).
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
    "models",
    nargs="*",
    default=[],
    help="Specify one or more models to install. If not set, install all models.",
)
parser.add_argument(
    "--skip",
    nargs="*",
    default=[],
    help="Skip models to install.",
)
parser.add_argument(
    "--canary",
    action="store_true",
    help="Install canary model.",
)
args, extra_args = parser.parse_known_args()


def install_test_bench_requirements():
    """Install per-model dependencies for the test_bench userbenchmark.

    Resolves the model directories selected via the CLI arguments
    (``models``, ``--skip``, ``--canary``), then for each model:
    runs the model's own ``install.py`` if it exists, otherwise
    pip-installs its ``requirements.txt``; models providing neither
    file are reported as skipped.

    Raises:
        subprocess.CalledProcessError: if a model's installer or
            pip invocation exits non-zero.
    """
    # Imported lazily so merely loading this script does not require
    # torchbenchmark to be importable.
    from torchbenchmark import _filter_model_paths

    model_paths = _filter_model_paths(args.models, args.skip, args.canary)
    for path in model_paths:
        model_name = os.path.basename(path)
        print(f"Installing {model_name}...")
        install_py_path = os.path.join(path, "install.py")
        requirements_txt_path = os.path.join(path, "requirements.txt")
        if os.path.exists(install_py_path):
            # Model ships its own installer; run it from the model directory.
            subprocess.check_call([sys.executable, install_py_path], cwd=path)
        elif os.path.exists(requirements_txt_path):
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "-q",
                    "-r",
                    requirements_txt_path,
                ],
                cwd=path,
            )
        else:
            # Fixed message typo: was "SKipped".
            print(f"Skipped: {model_name}")


# Entry point: run the installer only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    install_test_bench_requirements()

0 comments on commit 9ff23ab

Please sign in to comment.