cannot load the best checkpoint which is empty (openvinotoolkit#1187)
* add setting the best checkpoint before training

* remove unused parameter

* fix for better readability

* fix when best_checkpoints is empty at AdaptiveCompressionTrainingLoop
sungchul2 authored May 26, 2022
1 parent 4e628c0 commit 7f0f48a
Showing 3 changed files with 147 additions and 14 deletions.
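
In essence, the fix guards load_best_checkpoint against the case where no best checkpoint was ever saved: the old code indexed into self._best_checkpoints unconditionally, which fails (e.g. with a KeyError on the rate lookup) when the dict is empty. A minimal, self-contained sketch of the guard pattern (simplified, hypothetical names, not the actual NNCF runner API):

import logging

nncf_logger = logging.getLogger('nncf')

def load_best_checkpoint_guarded(model, best_checkpoints):
    """Sketch of the fixed logic; `best_checkpoints` maps compression rate -> checkpoint path."""
    if len(best_checkpoints) > 0:
        best_rate = sorted(best_checkpoints)[-1]  # highest compression rate wins
        path = best_checkpoints[best_rate]
        nncf_logger.info('Loading the best checkpoint found during training %s...', path)
        # The real runner follows up with torch.load(path, map_location='cpu')
        # and loads the resulting state dict into `model`.
    else:
        # Nothing was saved during training: keep the last (current) model state.
        nncf_logger.info('The best checkpoint has not been set yet. Return the last checkpoint...')
    return model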
2 changes: 1 addition & 1 deletion nncf/common/accuracy_aware_training/training_loop.py
@@ -102,8 +102,8 @@ def _run_early_exit_training_loop(self, model):
         rel_accuracy_drop = self._calculate_rel_accuracy_drop(uncompressed_model_accuracy,
                                                               compressed_model_accuracy)

+        self.runner.dump_statistics(model, self.compression_controller)
         if self._accuracy_criterion_satisfied(accuracy_budget, self.compression_controller):
-            self.runner.dump_statistics(model, self.compression_controller)
             nncf_logger.info('The accuracy criteria is reached after the initialization step.')
             self.print_accuracy_statistics(uncompressed_model_accuracy, compressed_model_accuracy,
                                            accuracy_drop, rel_accuracy_drop, accuracy_budget)
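
Read as a before/after, the hunk above moves dump_statistics out of the criterion check, so it now runs at the initialization step even when the loop can exit immediately; per the commit message, this is what sets the best checkpoint before training. A condensed sketch of the new control flow (hypothetical helper name, not the full loop):

def run_init_step(runner, model, controller, criterion_satisfied):
    # Now unconditional; previously this ran only when the accuracy
    # criterion was already satisfied at the initialization step.
    runner.dump_statistics(model, controller)
    if criterion_satisfied:
        return model  # early exit is safe: statistics were just dumped
    return None  # otherwise the regular training loop proceeds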
29 changes: 16 additions & 13 deletions nncf/torch/accuracy_aware_training/runner.py
@@ -254,19 +254,22 @@ def _save_best_checkpoint(self, checkpoint_path):
         copyfile(checkpoint_path, best_path)

     def load_best_checkpoint(self, model):
-        # load checkpoint with highest compression rate and positive acc budget
-        possible_checkpoint_rates = self.get_compression_rates_with_positive_acc_budget()
-        if not possible_checkpoint_rates:
-            nncf_logger.warning('Could not produce a compressed model satisfying the set accuracy '
-                                'degradation criterion during training. Increasing the number of training '
-                                'epochs')
-        best_checkpoint_compression_rate = sorted(possible_checkpoint_rates)[-1]
-        resuming_checkpoint_path = self._best_checkpoints[best_checkpoint_compression_rate]
-        nncf_logger.info('Loading the best checkpoint found during training '
-                         '{}...'.format(resuming_checkpoint_path))
-        resuming_checkpoint = torch.load(resuming_checkpoint_path, map_location='cpu')
-        resuming_model_state_dict = resuming_checkpoint.get('state_dict', resuming_checkpoint)
-        load_state(model, resuming_model_state_dict, is_resume=True)
+        if len(self._best_checkpoints) > 0:
+            # load checkpoint with highest compression rate and positive acc budget
+            possible_checkpoint_rates = self.get_compression_rates_with_positive_acc_budget()
+            if not possible_checkpoint_rates:
+                nncf_logger.warning('Could not produce a compressed model satisfying the set accuracy '
+                                    'degradation criterion during training. Increasing the number of training '
+                                    'epochs')
+            best_checkpoint_compression_rate = sorted(possible_checkpoint_rates)[-1]
+            resuming_checkpoint_path = self._best_checkpoints[best_checkpoint_compression_rate]
+            nncf_logger.info('Loading the best checkpoint found during training '
+                             '{}...'.format(resuming_checkpoint_path))
+            resuming_checkpoint = torch.load(resuming_checkpoint_path, map_location='cpu')
+            resuming_model_state_dict = resuming_checkpoint.get('state_dict', resuming_checkpoint)
+            load_state(model, resuming_model_state_dict, is_resume=True)
+        else:
+            nncf_logger.info('The best checkpoint has not been set yet. Return the last checkpoint...')

     @property
     def compressed_training_history(self):
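
Exercising both branches of the guard, reusing the load_best_checkpoint_guarded sketch from above (illustrative values only):

model = object()  # stand-in for a torch.nn.Module

# Empty dict: logs the 'not set yet' message and keeps the current state.
load_best_checkpoint_guarded(model, best_checkpoints={})

# Keyed by compression rate: the highest rate (0.7 here) is selected.
load_best_checkpoint_guarded(model, best_checkpoints={0.5: 'ckpt_0.50.pth',
                                                      0.7: 'ckpt_0.70.pth'})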
130 changes: 130 additions & 0 deletions tests/torch/accuracy_aware_training/test_training_loop.py
@@ -132,6 +132,79 @@ def configure_optimizers_fn():
     assert validate_fn(model, train_loader=train_loader) == pytest.approx(reference_final_metric, 1e-4)


+@pytest.mark.parametrize(
+    ('max_accuracy_degradation', 'maximal_total_epochs', 'initial_compression_rate_step'),
+    (({'maximal_absolute_accuracy_degradation': 0.1}, 1, 0.01),)
+)
+def test_adaptive_compression_training_loop_with_no_training(
+        max_accuracy_degradation,
+        maximal_total_epochs,
+        initial_compression_rate_step,
+        learning_rate=1e-3,
+        initial_training_phase_epochs=1,
+        patience_epochs=3,
+):
+    """Covers the case when the conditions below for the adaptive compression
+    training loop are not satisfied:
+    - self.runner.compression_rate_step >= self.runner.minimal_compression_rate_step
+    - self.runner.cumulative_epoch_count < self.runner.maximal_total_epochs
+    """
+
+    def mock_validate_fn(model, init_step=False, epoch=0):
+        original_metric = 0.85
+        if init_step:
+            return original_metric
+
+        return original_metric - 0.04 * epoch
+
+    input_sample_size = [1, 1, LeNet.INPUT_SIZE[-1], LeNet.INPUT_SIZE[-1]]
+    config = get_basic_magnitude_sparsity_config(input_sample_size=input_sample_size)
+
+    params = {
+        "initial_training_phase_epochs": initial_training_phase_epochs,
+        "patience_epochs": patience_epochs,
+        "maximal_total_epochs": maximal_total_epochs,
+        "initial_compression_rate_step": initial_compression_rate_step
+    }
+    params.update(max_accuracy_degradation)
+    accuracy_aware_config = {
+        "accuracy_aware_training": {
+            "mode": "adaptive_compression_level",
+            "params": params
+        }
+    }
+
+    config.update(accuracy_aware_config)
+
+    train_loader = create_ones_mock_dataloader(config, num_samples=10)
+    model = LeNet()
+
+    config = register_default_init_args(config,
+                                        train_loader=train_loader,
+                                        model_eval_fn=partial(mock_validate_fn, init_step=True))
+
+    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
+
+    def train_fn(compression_ctrl, model, optimizer,
+                 train_loader=train_loader, **kwargs):
+        pass
+
+    def configure_optimizers_fn():
+        optimizer = SGD(model.parameters(), lr=learning_rate)
+        return optimizer, None
+
+    acc_aware_training_loop = AdaptiveCompressionTrainingLoop(config, compression_ctrl)
+
+    model = acc_aware_training_loop.run(model,
+                                        train_epoch_fn=train_fn,
+                                        validate_fn=partial(mock_validate_fn, init_step=False),
+                                        configure_optimizers_fn=configure_optimizers_fn)
+    assert len(acc_aware_training_loop.runner._best_checkpoints) == 0
+
+    possible_checkpoint_compression_rates = \
+        acc_aware_training_loop.runner.get_compression_rates_with_positive_acc_budget()
+    assert len(possible_checkpoint_compression_rates) == 1

 @pytest.mark.parametrize(
     'max_accuracy_degradation',
     (({'maximal_relative_accuracy_degradation': 30.0}), ({'maximal_relative_accuracy_degradation': 1.0}),
@@ -267,6 +340,63 @@ def configure_optimizers_fn():
     assert epoch_counter == exit_epoch_number


+@pytest.mark.parametrize(
+    ('max_accuracy_degradation'),
+    (({'maximal_absolute_accuracy_degradation': 0.1}),)
+)
+def test_early_exit_with_mock_validation_and_no_improvement(
+        max_accuracy_degradation, maximal_total_epochs=5
+):
+    def mock_validate_fn(model, init_step=False, epoch=0):
+        original_metric = 0.85
+        if init_step:
+            return original_metric
+
+        return original_metric - 0.11 * (epoch + 1)
+
+    config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
+
+    params = {
+        "maximal_total_epochs": maximal_total_epochs
+    }
+    params.update(max_accuracy_degradation)
+    accuracy_aware_config = {
+        "accuracy_aware_training": {
+            "mode": "early_exit",
+            "params": params
+        }
+    }
+
+    config.update(accuracy_aware_config)
+
+    train_loader = create_ones_mock_dataloader(config, num_samples=10)
+    model = LeNet()
+
+    config = register_default_init_args(config,
+                                        train_loader=train_loader,
+                                        model_eval_fn=partial(mock_validate_fn, init_step=True))
+
+    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
+
+    def train_fn(compression_ctrl, model, epoch, optimizer, lr_scheduler,
+                 train_loader=train_loader):
+        pass
+
+    def configure_optimizers_fn():
+        optimizer = SGD(model.parameters(), lr=1e-3)
+        return optimizer, None
+
+    early_stopping_training_loop = EarlyExitCompressionTrainingLoop(config, compression_ctrl,
+                                                                    dump_checkpoints=False)
+    assert early_stopping_training_loop.runner._best_checkpoint is None
+
+    model = early_stopping_training_loop.run(model,
+                                             train_epoch_fn=train_fn,
+                                             validate_fn=partial(mock_validate_fn, init_step=False),
+                                             configure_optimizers_fn=configure_optimizers_fn)
+    assert early_stopping_training_loop.runner._best_checkpoint is not None


 @pytest.mark.parametrize('aa_config', (
     {
         "accuracy_aware_training": {
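
Why test_early_exit_with_mock_validation_and_no_improvement never satisfies the accuracy budget: with maximal_absolute_accuracy_degradation set to 0.1 and an uncompressed metric of 0.85, any validated metric below 0.75 is over budget, and the mock returns 0.85 - 0.11 * (epoch + 1), i.e. 0.74 at epoch 0 and lower afterwards. A quick check of that arithmetic:

original_metric = 0.85     # the mock's init-step (uncompressed) accuracy
max_abs_degradation = 0.1  # budget from the test's parametrization

for epoch in range(5):     # maximal_total_epochs=5 in the test
    metric = original_metric - 0.11 * (epoch + 1)
    assert original_metric - metric > max_abs_degradation  # over budget every epoch

The pair of assertions in the test thus pins down the fixed behavior: runner._best_checkpoint starts as None and is set by the initialization-step dump even though no training epoch ever improves on it.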

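To run only the two new tests, pytest's -k name filter should work from the repository root (shown via pytest.main to keep the example in Python; the filter is just a substring match on the test names):

import pytest

# Selects both new tests added by this commit.
pytest.main(['tests/torch/accuracy_aware_training/test_training_loop.py',
             '-k', 'no_training or no_improvement'])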