From f9ce0ae287c3440e76bfe4bd71f27bfc42174d26 Mon Sep 17 00:00:00 2001 From: morgoth95 Date: Tue, 10 May 2022 11:19:09 +0200 Subject: [PATCH] fix: fix bug with cuda devices in clip-nebullvm --- server/clip_server/model/clip_nebullvm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/clip_server/model/clip_nebullvm.py b/server/clip_server/model/clip_nebullvm.py index 1439256c5..233084b95 100644 --- a/server/clip_server/model/clip_nebullvm.py +++ b/server/clip_server/model/clip_nebullvm.py @@ -77,17 +77,21 @@ class EnvRunner: def __init__(self, device: str, num_threads: int = None): self.device = device self.cuda_str = None + self.rm_cuda_flag = False self.num_threads = num_threads def __enter__(self): if self.device == "cpu" and torch.cuda.is_available(): - self.cuda_str = os.environ.get("CUDA_VISIBLE_DEVICES") or "1" + self.cuda_str = os.environ.get("CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = "0" + self.rm_cuda_flag = self.cuda_str is None if self.num_threads is not None: os.environ["NEBULLVM_THREADS_PER_MODEL"] = f"{self.num_threads}" def __exit__(self, exc_type, exc_val, exc_tb): if self.cuda_str is not None: os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_str + elif self.rm_cuda_flag: + os.environ.pop("CUDA_VISIBLE_DEVICES") if self.num_threads is not None: os.environ.pop("NEBULLVM_THREADS_PER_MODEL")