Skip to content

Commit

Permalink
Apply minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Oct 4, 2023
1 parent 87ec7b6 commit ec2fe0f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions hezar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,13 @@ def _unpack_prediction_kwargs(self, **kwargs):
Returns:
A 3-sized tuple of (preprocess_kwargs, forward_kwargs, post_process_kwargs)
"""
# Whether to use forward or generate based on model type (Model or GenerativeModel)
inference_fn = type(self).generate if hasattr(self, "generate") else type(self).forward

# Get keyword arguments from the child class (ignore self, first arg and **kwargs)
preprocess_kwargs_keys = list(dict(inspect.signature(type(self).preprocess).parameters).keys())[2:-1]
post_process_kwargs_keys = list(dict(inspect.signature(type(self).post_process).parameters).keys())[2:-1]
if hasattr(self, "generate"):
# Get `generate` kwargs instead of `forward` if the model is generative
forward_kwargs_keys = list(dict(inspect.signature(type(self).generate).parameters).keys())[2:-1]
else:
forward_kwargs_keys = list(dict(inspect.signature(type(self).forward).parameters).keys())[2:-1]
forward_kwargs_keys = list(dict(inspect.signature(inference_fn).parameters).keys())[2:-1]

preprocess_kwargs = {k: kwargs.get(k) for k in preprocess_kwargs_keys if k in kwargs}
forward_kwargs = {k: kwargs.get(k) for k in forward_kwargs_keys if k in kwargs}
Expand Down

0 comments on commit ec2fe0f

Please sign in to comment.