forked from Project-MONAI/MONAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nvtx.py
176 lines (144 loc) · 6.71 KB
/
nvtx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Decorators and context managers for NVIDIA Tools Extension to profile MONAI components
"""
from __future__ import annotations
from collections import defaultdict
from functools import wraps
from typing import Any
from torch.autograd import Function
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import Dataset
from monai.utils import ensure_tuple, optional_import
# NVTX bindings obtained through MONAI's optional_import; the second returned value is unused here.
# NOTE(review): presumably `_nvtx` is a placeholder that raises with the descriptor message on use when
# torch._C._nvtx is unavailable (e.g. a non-CUDA build) — confirm against monai.utils.optional_import.
_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
# Public API of this module.
__all__ = ["Range"]
class Range:
    """
    A decorator and context manager for NVIDIA Tools Extension (NVTX) Range for profiling.
    When used as a decorator it encloses a specific method of the object with an NVTX Range.
    When used as a context manager, it encloses the runtime context (created by with statement) with an NVTX Range.

    Args:
        name: the name to be associated to the range.
            If None (default), when used as a decorator the name is derived from the decorated object's type
            (a numeric suffix such as ``_2`` is appended if that name has already been issued); when used as a
            context manager the name becomes ``context_<n>`` with a running counter ``n``.
        methods: (only when used as decorator) the name of a method (or a list of the name of the methods)
            to be wrapped by NVTX range.
            If None (default), the method(s) will be inferred based on the object's type for various MONAI components,
            such as Networks, Losses, Optimizers, Functions, Transforms, and Datasets.
            For objects that are none of the recognized types, it looks up these predefined methods:
            "forward", "__call__", "__next__", "__getitem__".
        append_method_name: whether to append the name of the decorated method to the range's name.
            If None (default), it appends the method's name only if we are annotating more than one method.
        recursive: if set to True, it will recursively annotate every individual module in a list
            or in a chain of modules (chained using Compose). Default to False.

    Raises:
        ValueError: when used as a decorator with ``methods=None`` on an object whose methods cannot be inferred.
    """

    # Shared by all instances: counts how many times each base name has been issued so that
    # auto-generated range names stay distinguishable (e.g. "Conv", then "Conv_2", ...).
    name_counter: dict = defaultdict(int)

    def __init__(
        self,
        name: str | None = None,
        methods: str | tuple[str, ...] | None = None,
        append_method_name: bool | None = None,
        recursive: bool = False,
    ) -> None:
        self.name = name
        self.methods = methods
        self.append_method_name = append_method_name
        self.recursive = recursive

    def __call__(self, obj: Any) -> Any:
        """
        Decorate ``obj`` in place by wrapping its method(s) with an NVTX range.

        The resolved ``name``, ``methods`` and ``append_method_name`` are stored back on this instance,
        so reusing the same ``Range`` instance on another object reuses them.

        Args:
            obj: the object whose method(s) should be wrapped or, when ``recursive=True``, a list/tuple of
                such objects or a ``Compose`` whose ``transforms`` should also be annotated.

        Returns:
            ``obj`` itself (decorated in place), or, for a list/tuple with ``recursive=True``, a new
            container of the same type holding the decorated elements.
        """
        if self.recursive is True:
            if isinstance(obj, (list, tuple)):
                # Each element gets its own Range so that it is named after its own type.
                return type(obj)(Range(recursive=True)(t) for t in obj)

            # NOTE(review): imported locally, presumably to avoid a circular import with monai.transforms.
            from monai.transforms.compose import Compose

            if isinstance(obj, Compose):
                obj.transforms = Range(recursive=True)(obj.transforms)

            # The container itself is then annotated below as a regular (non-recursive) object.
            self.recursive = False

        # Define the name to be associated to the range if not provided
        if self.name is None:
            name = type(obj).__name__
            # If CuCIM or TorchVision transform wrappers are being used,
            # append the underlying transform to the name for more clarity
            if "CuCIM" in name or "TorchVision" in name:
                name = f"{name}_{obj.name}"
            self.name_counter[name] += 1
            count = self.name_counter[name]
            self.name = f"{name}_{count}" if count > 1 else name

        # Define the methods to be wrapped if not provided
        self.methods = self._get_method(obj) if self.methods is None else ensure_tuple(self.methods)

        # Unless set explicitly, only append method names when they are needed to tell several ranges apart
        if self.append_method_name is None:
            self.append_method_name = len(self.methods) > 1

        for method in self.methods:
            self._decorate_method(obj, method, self.append_method_name)

        return obj

    def _decorate_method(self, obj: Any, method: str, append_method_name: bool) -> None:
        """
        Replace ``obj.<method>`` with a wrapper that pushes an NVTX range before the call and pops it after.

        Args:
            obj: the object to be decorated (modified in place).
            method: name of the method to wrap.
            append_method_name: if True, the range is named ``"<self.name>.<method>"`` instead of ``self.name``.
        """
        name = f"{self.name}.{method}" if append_method_name else self.name

        # Special ("__"-prefixed) methods are fetched from, and later patched on, the class rather than the instance
        owner = type(obj) if method.startswith("__") else obj

        _temp_func = getattr(owner, method)

        @wraps(_temp_func)
        def range_wrapper(*args, **kwargs):
            _nvtx.rangePushA(name)
            try:
                return _temp_func(*args, **kwargs)
            finally:
                # Pop even when the wrapped call raises, so NVTX pushes and pops stay balanced.
                _nvtx.rangePop()

        if method.startswith("__"):
            # Implicit invocations of special methods (e.g. obj(...), obj[i]) look the method up on the type,
            # not the instance, so patch a fresh subclass assigned to this object only, leaving the
            # original class untouched for its other instances.
            class NVTXRangeDecoratedClass(owner):  # type: ignore
                ...

            setattr(NVTXRangeDecoratedClass, method, range_wrapper)
            obj.__class__ = NVTXRangeDecoratedClass
        else:
            setattr(owner, method, range_wrapper)

    def _get_method(self, obj: Any) -> tuple:
        """
        Infer which method(s) of ``obj`` to wrap based on its type.

        Returns:
            a tuple of method names.

        Raises:
            ValueError: if ``obj`` is not a recognized type and has none of the default methods.
        """
        if isinstance(obj, Module):
            method_list = ["forward"]
        elif isinstance(obj, Optimizer):
            method_list = ["step"]
        elif isinstance(obj, Function):
            method_list = ["forward", "backward"]
        elif isinstance(obj, Dataset):
            method_list = ["__getitem__"]
        else:
            default_methods = ["forward", "__call__", "__next__", "__getitem__"]
            method_list = [m for m in default_methods if hasattr(obj, m)]
            if not method_list:
                raise ValueError(
                    f"The method to be wrapped for this object [{type(obj)}] is not recognized. "
                    "The name of the method should be provided or the object should have one of these methods: "
                    f"{default_methods}"
                )
        return ensure_tuple(method_list)

    def __enter__(self):
        """Push an NVTX range named ``self.name`` (auto-generated as ``context_<n>`` when unset)."""
        if self.name is None:
            # Number the range with class variable counter to avoid duplicate names.
            self.name_counter["context"] += 1
            self.name = f"context_{self.name_counter['context']}"
        _nvtx.rangePushA(self.name)

    def __exit__(self, type, value, traceback):
        """Pop the NVTX range pushed by ``__enter__``; exceptions are not suppressed."""
        _nvtx.rangePop()