Skip to content

Commit

Permalink
nit: Fix/add some type annotations (#1982)
Browse files Browse the repository at this point in the history
Co-authored-by: Felipe Mello <[email protected]>
  • Loading branch information
bradhilton and felipemello1 authored Nov 13, 2024
1 parent 1eb7785 commit 4b6877a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/source/deep_dives/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ this by taking a look at the :func:`~torchtune.config.instantiate` API.
def instantiate(
config: DictConfig,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
)
:func:`~torchtune.config.instantiate` also accepts positional arguments
Expand Down
12 changes: 6 additions & 6 deletions torchtune/config/_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def _create_component(
_component_: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
):
) -> Any:
return _component_(*args, **kwargs)


def _instantiate_node(node: Dict[str, Any], *args: Tuple[Any, ...]):
def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any:
"""
Creates the object specified in _component_ field with provided positional args
and kwargs already merged. Raises an InstantiationError if _component_ is not specified.
Expand All @@ -40,8 +40,8 @@ def _instantiate_node(node: Dict[str, Any], *args: Tuple[Any, ...]):

def instantiate(
config: DictConfig,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Any:
"""
Given a DictConfig with a _component_ field specifying the object to instantiate and
Expand All @@ -60,8 +60,8 @@ def instantiate(
config (DictConfig): a single field in the OmegaConf object parsed from the yaml file.
This is expected to have a _component_ field specifying the path of the object
to instantiate.
*args (Tuple[Any, ...]): positional arguments to pass to the object to instantiate.
**kwargs (Dict[str, Any]): keyword arguments to pass to the object to instantiate.
*args (Any): positional arguments to pass to the object to instantiate.
**kwargs (Any): keyword arguments to pass to the object to instantiate.
Examples:
>>> config.yaml:
Expand Down
4 changes: 2 additions & 2 deletions torchtune/config/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]:
return namespace, unknown_args


def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]:
def parse(recipe_main: Recipe) -> Callable[..., Any]:
"""
Decorator that handles parsing the config file and CLI overrides
for a recipe. Use it on the recipe's main function.
Expand All @@ -83,7 +83,7 @@ def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]:
>>> tune my_recipe --config config.yaml foo=bar
Returns:
Callable[[Recipe], Any]: the decorated main
Callable[..., Any]: the decorated main
"""

@functools.wraps(recipe_main)
Expand Down
18 changes: 9 additions & 9 deletions torchtune/modules/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Generator, Optional, Tuple
from typing import Any, Dict, Generator, Optional
from warnings import warn

import torch
Expand All @@ -24,10 +24,10 @@
def reparametrize_as_dtype_state_dict_post_hook(
model: nn.Module,
state_dict: Dict[str, Any],
*args: Tuple[Any, ...],
*args: Any,
dtype: torch.dtype = torch.bfloat16,
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
**kwargs: Any,
):
"""
A state_dict hook that replaces NF4 tensors with their restored
Expand All @@ -47,10 +47,10 @@ def reparametrize_as_dtype_state_dict_post_hook(
Args:
model (nn.Module): the model to take ``state_dict()`` on
state_dict (Dict[str, Any]): the state dict to modify
*args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
*args (Any): Unused args passed when running this as a state_dict hook.
dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
**kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
**kwargs (Any): Unused keyword args passed when running this as a state_dict hook.
"""
for k, v in state_dict.items():
if isinstance(v, NF4Tensor):
Expand All @@ -62,10 +62,10 @@ def reparametrize_as_dtype_state_dict_post_hook(
def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
model: nn.Module,
state_dict: Dict[str, Any],
*args: Tuple[Any, ...],
*args: Any,
dtype: torch.dtype = torch.bfloat16,
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
**kwargs: Any,
):
"""
A state_dict hook that replaces NF4 tensors with their restored
Expand All @@ -88,10 +88,10 @@ def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
Args:
model (nn.Module): the model to take ``state_dict()`` on
state_dict (Dict[str, Any]): the state dict to modify
*args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
*args (Any): Unused args passed when running this as a state_dict hook.
dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
**kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
**kwargs (Any): Unused keyword args passed when running this as a state_dict hook.
"""
# Create a state dict of FakeTensors that matches the state_dict
mode = FakeTensorMode()
Expand Down

0 comments on commit 4b6877a

Please sign in to comment.