Skip to content

Commit

Permalink
Tests for torch device and dtype utils (#2580)
Browse files Browse the repository at this point in the history
### Changes

This PR addresses #2579.
- Add tests for torch device utils. The tests consider the case in which
the model has no parameters, has all parameters on CPU, has all
parameters on CUDA, and has parameters placed on different devices. In
the latter, the parameters are moved on different devices randomly.
- Add tests for torch `dtype` utils. The case in which the model has no
parameters is also considered.
- Created a test torch helper class `EmptyModel` to take into account
the case in which the model has no parameters at all.
- Add docstrings in utils.

### Tests

I compared all the results manually checking for their correctness. The
code is also compliant with the coding style having been verified with
`pre-commit run`

---------

Co-authored-by: Daniil Lyakhov <[email protected]>
  • Loading branch information
DaniAffCH and daniil-lyakhov authored Mar 18, 2024
1 parent 86ea8f6 commit 71cbe90
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 20 deletions.
24 changes: 24 additions & 0 deletions nncf/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,14 @@ def maybe_convert_legacy_names_in_compress_state(compression_state: Dict[str, An


def get_model_device(model: torch.nn.Module) -> torch.device:
"""
Get the device on which the first model parameters reside.
:param model: The PyTorch model.
:return: The device where the first model parameter reside.
Default cpu if the model has no parameters.
"""

try:
device = next(model.parameters()).device
except StopIteration:
Expand All @@ -427,6 +435,14 @@ def get_all_model_devices_generator(model: torch.nn.Module) -> Generator[torch.d


def is_multidevice(model: torch.nn.Module) -> bool:
"""
Checks if the model's parameters are distributed across multiple devices.
:param model: The PyTorch model.
:return: True if the parameters reside on multiple devices, False otherwise.
Default False if the models has no parameters
"""

device_generator = get_all_model_devices_generator(model)
try:
curr_device = next(device_generator)
Expand All @@ -440,6 +456,14 @@ def is_multidevice(model: torch.nn.Module) -> bool:


def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
"""
Get the datatype of the first model parameter.
:param model: The PyTorch model.
:return: The datatype of the first model parameter.
Default to torch.float32 if the model has no parameters.
"""

try:
dtype = next(model.parameters()).dtype
except StopIteration:
Expand Down
8 changes: 8 additions & 0 deletions tests/torch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,14 @@ def forward(self, *input_, **kwargs):
return None


class EmptyModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, *input_, **kwargs):
return None


def check_correct_nncf_modules_replacement(
model: torch.nn.Module, compressed_model: NNCFNetwork
) -> Tuple[Dict[Scope, Module], Dict[Scope, Module]]:
Expand Down
76 changes: 56 additions & 20 deletions tests/torch/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
from torch import nn

from nncf.torch.initialization import DataLoaderBNAdaptationRunner
from nncf.torch.layer_utils import CompressionParameter
from nncf.torch.utils import _ModuleState
from nncf.torch.utils import get_model_device
from nncf.torch.utils import get_model_dtype
from nncf.torch.utils import is_multidevice
from nncf.torch.utils import save_module_state
from nncf.torch.utils import training_mode_switcher
from tests.torch.helpers import BasicConvTestModel
from tests.torch.helpers import EmptyModel
from tests.torch.helpers import MockModel
from tests.torch.helpers import TwoConvTestModel
from tests.torch.quantization.test_overflow_issue_export import DepthWiseConvTestModel
Expand All @@ -33,29 +36,19 @@ def compare_saved_model_state_and_current_model_state(model: nn.Module, model_st
assert param.requires_grad == model_state.requires_grad_state[name]


def randomly_change_model_state(module: nn.Module, compression_params_only: bool = False):
import random
def change_model_state(module: nn.Module):
for i, ch in enumerate(module.modules()):
ch.training = i % 2 == 0

for ch in module.modules():
if random.uniform(0, 1) > 0.5:
ch.training = False
else:
ch.training = True

for p in module.parameters():
if compression_params_only and not (isinstance(p, CompressionParameter) and torch.is_floating_point(p)):
break
if random.uniform(0, 1) > 0.5:
p.requires_grad = False
else:
p.requires_grad = True
for i, p in enumerate(module.parameters()):
p.requires_grad = i % 2 == 0


@pytest.mark.parametrize(
"model", [BasicConvTestModel(), TwoConvTestModel(), MockModel(), DepthWiseConvTestModel(), EightConvTestModel()]
)
def test_training_mode_switcher(_seed, model: nn.Module):
randomly_change_model_state(model)
def test_training_mode_switcher(model: nn.Module):
change_model_state(model)
saved_state = save_module_state(model)
with training_mode_switcher(model, True):
pass
Expand All @@ -66,7 +59,7 @@ def test_training_mode_switcher(_seed, model: nn.Module):
@pytest.mark.parametrize(
"model", [BasicConvTestModel(), TwoConvTestModel(), MockModel(), DepthWiseConvTestModel(), EightConvTestModel()]
)
def test_bn_training_state_switcher(_seed, model: nn.Module):
def test_bn_training_state_switcher(model: nn.Module):
def check_were_only_bn_training_state_changed(model: nn.Module, saved_state: _ModuleState):
for name, module in model.named_modules():
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
Expand All @@ -76,10 +69,53 @@ def check_were_only_bn_training_state_changed(model: nn.Module, saved_state: _Mo

runner = DataLoaderBNAdaptationRunner(model, "cuda")

randomly_change_model_state(model)
for p in model.parameters():
p.requires_grad = False

saved_state = save_module_state(model)

with runner._bn_training_state_switcher():
check_were_only_bn_training_state_changed(model, saved_state)

compare_saved_model_state_and_current_model_state(model, saved_state)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available in current environment")
def test_model_device():
model = TwoConvTestModel()
cuda = torch.device("cuda")

assert not is_multidevice(model)
assert get_model_device(model).type == "cpu"

model.features[0][0].to(cuda)

assert is_multidevice(model)
assert get_model_device(model).type == "cuda"

model.to(cuda)

assert not is_multidevice(model)
assert get_model_device(model).type == "cuda"


def test_empty_model_device():
model = EmptyModel()

assert not is_multidevice(model)
assert get_model_device(model).type == "cpu"


def test_model_dtype():
model = BasicConvTestModel()
model.to(torch.float16)
assert get_model_dtype(model) == torch.float16
model.to(torch.float32)
assert get_model_dtype(model) == torch.float32
model.to(torch.float64)
assert get_model_dtype(model) == torch.float64


def test_empty_model_dtype():
model = EmptyModel()
assert get_model_dtype(model) == torch.float32

0 comments on commit 71cbe90

Please sign in to comment.