Skip to content

Commit

Permalink
Fix dtype getter (#17668)
Browse files Browse the repository at this point in the history
* Fix dtype getters

* Proper fix for dtype getter

* Style and commant

* Always use last for consistency

* Quality
  • Loading branch information
sgugger authored Jun 13, 2022
1 parent 7308358 commit a1344db
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
try:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
# For nn.DataParallel compatibility in PyTorch > 1.5

def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
Expand All @@ -152,31 +152,33 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:

def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the first dtype it found.
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
try:
for t in parameter.parameters():
if t.is_floating_point():
return t.dtype
# if no floating dtype was found return whatever the first dtype is
else:
return next(parameter.parameters()).dtype
last_dtype = None
for t in parameter.parameters():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype

except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype

else:
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples

gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
# fallback to any dtype the model has even if not floating
else:
first_tuple = next(gen)
return first_tuple[1].dtype

# fallback to the last dtype
return last_tuple[1].dtype


def get_state_dict_float_dtype(state_dict):
Expand Down

0 comments on commit a1344db

Please sign in to comment.