diff --git a/.gitignore b/.gitignore index b6e4761..064966e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +data/* +predictions/* + +!.placeholder + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/esc/predict.py b/esc/predict.py index 6e74b6b..f21a312 100644 --- a/esc/predict.py +++ b/esc/predict.py @@ -305,6 +305,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--prediction-types", type=str, required=True, nargs="+") parser.add_argument("--evaluate", action="store_true", default=False) # default + not required + parser.add_argument("--cpu", action="store_true", default=False) parser.add_argument("--device", type=int, default=0) parser.add_argument("--tokens-per-batch", type=int, default=4000) parser.add_argument("--output-errors", action="store_true", default=False) @@ -328,7 +329,9 @@ def main() -> None: wsd_model = ESCModule.load_from_checkpoint(args.ckpt) wsd_model.freeze() - if args.device >= 0: + if args.cpu: + wsd_model.to("cpu") + elif args.device >= 0: wsd_model.to(torch.device(args.device)) tokenizer = get_tokenizer( diff --git a/esc/utils/commons.py b/esc/utils/commons.py index 7f6638c..6623f8b 100644 --- a/esc/utils/commons.py +++ b/esc/utils/commons.py @@ -27,6 +27,9 @@ def chunks(lst, n): def list_elems_in_dir(dir_path: str, only_files: bool = False, only_dirs: bool = False) -> List[str]: + if not isdir(dir_path): + return list() + elems_in_dir = [e for e in listdir(dir_path)] if only_files: diff --git a/requirements.txt b/requirements.txt index 574b9ab..c2d2067 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ pytorch-lightning==0.9.0 wandb nltk==3.4.5 nlp -black==21.5b2 \ No newline at end of file +black==21.5b2 +protobuf<=3.20.1 \ No newline at end of file