Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a setting for clearing CUDA cache after each inference #2930

Merged
merged 2 commits on Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backend/src/api/node_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ def storage_dir(self) -> Path:
"""

@abstractmethod
def add_cleanup(self, fn: Callable[[], None]) -> None:
def add_cleanup(
self, fn: Callable[[], None], after: Literal["node", "chain"] = "chain"
) -> None:
"""
Registers a function that will be called when the chain execution is finished.
Registers a function that will be called when the chain execution is finished (if set to chain mode) or after node execution is finished (node mode).

Registering the same function (object) twice will only result in the function being called once.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ def align_image_to_reference_node(
alignment_passes: int,
blur_strength: float,
) -> np.ndarray:
context.add_cleanup(safe_cuda_cache_empty)
exec_options = get_settings(context)
context.add_cleanup(
safe_cuda_cache_empty,
after="node" if exec_options.force_cache_wipe else "chain",
)
multiplier = precision.value / 1000
return align_images(
context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,15 @@ def guided_upscale_node(
iterations: float,
split_mode: SplitMode,
) -> np.ndarray:
context.add_cleanup(safe_cuda_cache_empty)
exec_options = get_settings(context)
context.add_cleanup(
safe_cuda_cache_empty,
after="node" if exec_options.force_cache_wipe else "chain",
)
return pix_transform_auto_split(
source=source,
guide=guide,
device=get_settings(context).device,
device=exec_options.device,
params=Params(iteration=int(iterations * 1000)),
split_mode=split_mode,
)
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def inpaint_node(
), "Input image and mask must have the same resolution"

exec_options = get_settings(context)

context.add_cleanup(safe_cuda_cache_empty)
context.add_cleanup(
safe_cuda_cache_empty,
after="node" if exec_options.force_cache_wipe else "chain",
)

return inpaint(img, mask, model, exec_options)
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,10 @@ def upscale_image_node(
) -> np.ndarray:
exec_options = get_settings(context)

context.add_cleanup(safe_cuda_cache_empty)
context.add_cleanup(
safe_cuda_cache_empty,
after="node" if exec_options.force_cache_wipe else "chain",
)

in_nc = model.input_channels
out_nc = model.output_channels
Expand Down Expand Up @@ -296,5 +299,4 @@ def inner_upscale(img: np.ndarray) -> np.ndarray:
if not use_custom_scale or scale == 1 or in_nc != out_nc:
# no custom scale
custom_scale = scale

return custom_scale_upscale(img, inner_upscale, scale, custom_scale, separate_alpha)
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def wavelet_color_fix_node(
)

exec_options = get_settings(context)
context.add_cleanup(safe_cuda_cache_empty)
context.add_cleanup(
safe_cuda_cache_empty,
after="node" if exec_options.force_cache_wipe else "chain",
)
device = exec_options.device

# convert to tensors
Expand Down
12 changes: 12 additions & 0 deletions backend/src/packages/chaiNNer_pytorch/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,24 @@
)
)

# Only expose the cache-wipe toggle when an NVIDIA GPU is detected —
# clearing the CUDA cache has no effect on CPU-only installs.
if nvidia.is_available:
    package.add_setting(
        ToggleSetting(
            label="Force CUDA Cache Wipe (not recommended)",
            # Read back by get_settings() as PyTorchSettings.force_cache_wipe.
            key="force_cache_wipe",
            description="Clears PyTorch's CUDA cache after each inference. This is NOT recommended, by us or PyTorch's developers, as it basically interferes with how PyTorch is intended to work and can significantly slow down inference time. Only enable this if you're experiencing issues with VRAM allocation.",
            # Off by default: the description warns this can significantly
            # slow down inference.
            default=False,
        )
    )


@dataclass(frozen=True)
class PyTorchSettings:
use_cpu: bool
use_fp16: bool
gpu_index: int
budget_limit: int
force_cache_wipe: bool = False

# PyTorch 2.0 does not support FP16 when using CPU
def __post_init__(self):
Expand Down Expand Up @@ -122,4 +133,5 @@ def get_settings(context: NodeContext) -> PyTorchSettings:
use_fp16=settings.get_bool("use_fp16", False),
gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
budget_limit=settings.get_int("budget_limit", 0, parse_str=True),
force_cache_wipe=settings.get_bool("force_cache_wipe", False),
)
26 changes: 21 additions & 5 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, List, NewType, Sequence, Union
from typing import Callable, Iterable, List, Literal, NewType, Sequence, Union

from sanic.log import logger

Expand Down Expand Up @@ -342,7 +342,8 @@ def __init__(
self.__settings = settings
self._storage_dir = storage_dir

self.cleanup_fns: set[Callable[[], None]] = set()
self.chain_cleanup_fns: set[Callable[[], None]] = set()
self.node_cleanup_fns: set[Callable[[], None]] = set()

@property
def aborted(self) -> bool:
Expand Down Expand Up @@ -373,8 +374,15 @@ def settings(self) -> SettingsParser:
def storage_dir(self) -> Path:
return self._storage_dir

def add_cleanup(self, fn: Callable[[], None]) -> None:
self.cleanup_fns.add(fn)
def add_cleanup(
    self, fn: Callable[[], None], after: Literal["node", "chain"] = "chain"
) -> None:
    """
    Register a cleanup function.

    With ``after="chain"`` (the default) the function runs once the whole
    chain execution is finished; with ``after="node"`` it runs after the
    current node finishes executing.

    The functions are stored in sets, so registering the same function
    object twice results in only a single call.
    """
    if after == "chain":
        self.chain_cleanup_fns.add(fn)
    elif after == "node":
        self.node_cleanup_fns.add(fn)
    else:
        # Defensive guard: `after` is typed as a Literal, but untyped
        # callers could still pass an unexpected value at runtime.
        raise ValueError(f"Unknown cleanup type: {after}")


class Executor:
Expand Down Expand Up @@ -591,6 +599,14 @@ def get_lazy_evaluation_time():
)
await self.progress.suspend()

# Run per-node cleanup functions registered via add_cleanup(..., after="node").
# NOTE: the set must not be mutated while iterating it — removing the current
# element inside the loop (as the original `finally: ...remove(fn)` did)
# raises "RuntimeError: Set changed size during iteration" in CPython.
# Iterate first, then clear the whole set once all functions have run.
for fn in context.node_cleanup_fns:
    try:
        fn()
    except Exception as e:
        # Best-effort: one failing cleanup must not prevent the others
        # from running.
        logger.error(f"Error running cleanup function: {e}")
context.node_cleanup_fns.clear()

lazy_time_after = get_lazy_evaluation_time()
execution_time -= lazy_time_after - lazy_time_before

Expand Down Expand Up @@ -824,7 +840,7 @@ async def __process_nodes(self):

# Run cleanup functions
for context in self.__context_cache.values():
for fn in context.cleanup_fns:
for fn in context.chain_cleanup_fns:
try:
fn()
except Exception as e:
Expand Down
Loading