Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
hotfix (#261)
Browse files Browse the repository at this point in the history
* hotfix

* update
  • Loading branch information
tchaton authored May 5, 2021
1 parent ec404e8 commit b3c049c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
7 changes: 6 additions & 1 deletion flash/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,12 @@ def __init__(self, func_to_wrap: Callable, model: 'Task') -> None:
def __call__(self, *args, **kwargs):
outputs = self.func(*args, **kwargs)

internal_running_state = self.internal_mapping[self.model.trainer._running_stage]
try:
stage = self.model.trainer._running_stage
except AttributeError:
stage = self.model.trainer.state.stage

internal_running_state = self.internal_mapping[stage]
additional_func = self._stage_mapping.get(internal_running_state, None)

if additional_func:
Expand Down
1 change: 1 addition & 0 deletions flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
class VideoClassifierFinetuning(BaseFinetuning):

def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1):
super().__init__()
self.num_layers = num_layers
self.train_bn = train_bn
self.unfreeze_epoch = unfreeze_epoch
Expand Down

0 comments on commit b3c049c

Please sign in to comment.