diff --git a/aim/pytorch_lightning.py b/aim/pytorch_lightning.py index b1e424d89..50d10c1aa 100644 --- a/aim/pytorch_lightning.py +++ b/aim/pytorch_lightning.py @@ -1,2 +1,2 @@ # Alias to SDK PyTorch Lightning interface -from aim.sdk.adapters.pytorch_lightning import AimLogger # noqa F401 \ No newline at end of file +from aim.sdk.adapters.pytorch_lightning import AimLogger # noqa F401 diff --git a/aim/sdk/adapters/pytorch_lightning.py b/aim/sdk/adapters/pytorch_lightning.py index cabd734c3..e08508e3c 100644 --- a/aim/sdk/adapters/pytorch_lightning.py +++ b/aim/sdk/adapters/pytorch_lightning.py @@ -48,9 +48,9 @@ def __init__( self, repo: Optional[str] = None, experiment: Optional[str] = None, - train_metric_prefix: Optional[str] = 'train_', # deprecated - val_metric_prefix: Optional[str] = 'val_', # deprecated - test_metric_prefix: Optional[str] = 'test_', # deprecated + train_metric_prefix: Optional[str] = 'train_', # deprecated + val_metric_prefix: Optional[str] = 'val_', # deprecated + test_metric_prefix: Optional[str] = 'test_', # deprecated system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT, log_system_params: Optional[bool] = True, capture_terminal_logs: Optional[bool] = True, @@ -89,7 +89,9 @@ def __init__( context_prefixes.pop('subset') # context_prefixes is now empty {} elif train_metric_prefix != 'train_' or val_metric_prefix != 'val_' or test_metric_prefix != 'test_': - raise ValueError('Arguments "train_metric_prefix" "val_metric_prefix" "train_metric_prefix" cannot be used in conjunction with "context_prefixes".') + raise ValueError( + 'Arguments "train_metric_prefix" "val_metric_prefix" "train_metric_prefix" cannot be used in conjunction with "context_prefixes".' + ) # Deprecation warnings if SUBSET_metric_prefix arguments are not default if train_metric_prefix != 'train_': msg = 'The argument "train_metric_prefix" is deprecated. Consider using "context_prefixes" instead.' @@ -180,14 +182,14 @@ def parse_context(self, name): for ctx, mappings in self._context_prefixes.items(): for category, prefix in mappings.items(): if name.startswith(prefix): - name = name[len(prefix):] + name = name[len(prefix) :] context[ctx] = category break # avoid prefix rename cascade for ctx, mappings in self._context_postfixes.items(): for category, postfix in mappings.items(): if name.endswith(postfix): - name = name[:-len(postfix)] + name = name[: -len(postfix)] context[ctx] = category break # avoid postfix rename cascade