
Feature/sg 1198 mixed precision automatically changed with warning (#1567)

* fix

* work with tmpdir

* minor change of comment

* improve device_config
Louis-Dupont authored Oct 25, 2023
1 parent ec21383 commit 34fda6c
Showing 5 changed files with 79 additions and 3 deletions.
16 changes: 15 additions & 1 deletion src/super_gradients/common/environment/device_utils.py
@@ -19,10 +19,24 @@ def _get_assigned_rank() -> int:
 
 @dataclasses.dataclass
 class DeviceConfig:
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    _device: str = "cuda" if torch.cuda.is_available() else "cpu"
     multi_gpu: str = None
     assigned_rank: int = dataclasses.field(default=_get_assigned_rank(), init=False)
+
+    @property
+    def device(self) -> str:
+        return self._device
+
+    @device.setter
+    def device(self, value: str):
+        if "cuda" in value and not torch.cuda.is_available():
+            raise ValueError("CUDA is not available, cannot set device to cuda")
+        self._device = value
+
+    @property
+    def is_cuda(self):
+        return "cuda" in self._device
 
 
 # Singleton holding the device information
 device_config = DeviceConfig()
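
What this change means in practice: `device` is now a validated property backed by `_device`, and `is_cuda` reports whether the active device is a CUDA device. A minimal usage sketch of the singleton after this change (illustrative only, not part of the commit; assumes a machine without CUDA):

# Hypothetical usage of the DeviceConfig singleton after this change (not in the commit).
from super_gradients.common.environment.device_utils import device_config

print(device_config.device)   # "cpu" on a CUDA-less machine
print(device_config.is_cuda)  # False: the active device is not a CUDA device

try:
    device_config.device = "cuda:0"  # routed through the new validating setter
except ValueError as err:
    print(err)  # "CUDA is not available, cannot set device to cuda"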
11 changes: 9 additions & 2 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -2,6 +2,7 @@
 import inspect
 import os
 import typing
+import warnings
 from copy import deepcopy
 from typing import Union, Tuple, Mapping, Dict, Any, List, Optional
 
@@ -1331,7 +1332,7 @@ def forward(self, inputs, targets):
 
         self.pre_prediction_callback = CallbacksFactory().get(self.training_params.pre_prediction_callback)
 
-        self._initialize_mixed_precision(self.training_params.mixed_precision)
+        self.training_params.mixed_precision = self._initialize_mixed_precision(self.training_params.mixed_precision)
 
         self.ckpt_best_name = self.training_params.ckpt_best_name
 
@@ -1601,11 +1602,16 @@ def _set_test_metrics(self, test_metrics_list):
         self.test_metrics = MetricCollection(test_metrics_list)
 
     def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
+
+        if mixed_precision_enabled and not device_config.is_cuda:
+            warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision. (i.e. `mixed_precision=False`)")
+            mixed_precision_enabled = False
+
         # SCALER IS ALWAYS INITIALIZED BUT IS DISABLED IF MIXED PRECISION WAS NOT SET
         self.scaler = GradScaler(enabled=mixed_precision_enabled)
 
         if mixed_precision_enabled:
-            assert device_config.device.startswith("cuda"), "mixed precision is not available for CPU"
+
             if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
                 # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
                 # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
@@ -1621,6 +1627,7 @@ def hook(module, _):
                 logger.warning("Mixed Precision - scaler state_dict not found in loaded model. This may case issues " "with loss scaling")
             else:
                 self.scaler.load_state_dict(scaler_state_dict)
+        return mixed_precision_enabled
 
     def _validate_final_average_model(self, context: PhaseContext, checkpoint_dir_path: str, cleanup_snapshots_pkl_file=False):
         """
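
The net effect of the trainer changes: requesting mixed precision on a non-CUDA device no longer trips an assert; it emits a warning, downgrades the flag, and the resolved value is written back to `training_params`. A standalone sketch of the same fallback logic (hypothetical helper name, not the Trainer implementation; `GradScaler` is the real torch.cuda.amp class):

# Standalone sketch (not the Trainer code) of the fallback introduced above:
# on a non-CUDA device, a mixed-precision request is downgraded to False with
# a warning, and the GradScaler is created disabled so its calls are no-ops.
import warnings

from torch.cuda.amp import GradScaler


def resolve_mixed_precision(requested: bool, is_cuda: bool):
    """Hypothetical helper mirroring the fallback in _initialize_mixed_precision."""
    if requested and not is_cuda:
        warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision.")
        requested = False
    # The scaler is always constructed; with enabled=False its scale/step/update calls are no-ops.
    return requested, GradScaler(enabled=requested)


enabled, scaler = resolve_mixed_precision(requested=True, is_cuda=False)  # warns and returns False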
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
@@ -24,6 +24,7 @@
     TestPostPredictionCallback,
     TestModelPredict,
     TestDeprecationDecorator,
+    TestMixedPrecisionDisabled,
 )
 from tests.end_to_end_tests import TestTrainer
 from tests.unit_tests.detection_utils_test import TestDetectionUtils
@@ -162,6 +163,7 @@ def _add_modules_to_unit_tests_suite(self):
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationModelExport))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(YoloNASPoseTests))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationSampleTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMixedPrecisionDisabled))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
2 changes: 2 additions & 0 deletions tests/unit_tests/__init__.py
@@ -26,6 +26,7 @@
 from tests.unit_tests.post_prediction_callback_test import TestPostPredictionCallback
 from tests.unit_tests.test_predict import TestModelPredict
 from tests.unit_tests.test_deprecate import TestDeprecationDecorator
+from tests.unit_tests.test_mixed_precision_cpu import TestMixedPrecisionDisabled
 
 __all__ = [
     "CrashTipTest",
@@ -55,4 +56,5 @@
     "TestPostPredictionCallback",
     "TestModelPredict",
     "TestDeprecationDecorator",
+    "TestMixedPrecisionDisabled",
 ]
51 changes: 51 additions & 0 deletions tests/unit_tests/test_mixed_precision_cpu.py
@@ -0,0 +1,51 @@
import unittest
import tempfile

from super_gradients import Trainer
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.models import ResNet18
from super_gradients.training.utils.distributed_training_utils import setup_device


class TestMixedPrecisionDisabled(unittest.TestCase):
    def test_mixed_precision_automatically_changed_with_warning(self):
        setup_device(device="cpu")

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = Trainer("test_mixed_precision_automatically_changed_with_warning", ckpt_root_dir=temp_dir)
            net = ResNet18(num_classes=5, arch_params={})
            train_params = {
                "max_epochs": 2,
                "lr_updates": [1],
                "lr_decay_factor": 0.1,
                "lr_mode": "StepLRScheduler",
                "lr_warmup_epochs": 0,
                "initial_lr": 0.1,
                "loss": "CrossEntropyLoss",
                "criterion_params": {"ignore_index": 0},
                "train_metrics_list": [Accuracy(), Top5()],
                "valid_metrics_list": [Accuracy(), Top5()],
                "metric_to_watch": "Accuracy",
                "greater_metric_to_watch_is_better": True,
                "mixed_precision": True,  # This is not supported for CPU, so we expect a warning to be raised AND the code to run
            }
            import warnings

            with warnings.catch_warnings(record=True) as w:
                # Trigger a filter to always make warnings visible
                warnings.simplefilter("always")

                trainer.train(
                    model=net,
                    training_params=train_params,
                    train_loader=classification_test_dataloader(batch_size=10),
                    valid_loader=classification_test_dataloader(batch_size=10),
                )

                # Check if the desired warning is in the list of warnings
                self.assertTrue(any("Mixed precision training is not supported on CPU" in str(warn.message) for warn in w))


if __name__ == "__main__":
    unittest.main()
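
A side note on the test: the same warning check could also be written with unittest's assertWarns, which additionally fails the test when no warning is raised at all. A self-contained sketch (illustrative only, not part of the commit):

# Illustrative alternative (not in the commit): unittest.assertWarns both
# captures the warning and fails the test if no warning is emitted.
import unittest
import warnings


class ExampleAssertWarns(unittest.TestCase):
    def test_warning_message(self):
        # Stand-in for trainer.train(...) on CPU with mixed_precision=True.
        with self.assertWarns(UserWarning) as cm:
            warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision.")
        self.assertIn("not supported on CPU", str(cm.warning))


if __name__ == "__main__":
    unittest.main()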
