Unifying training argument type annotations (#17934)
* doc: Unify training arg type annotations

* wip: extracting enum type from Union

* blackening
jannisborn authored Jun 30, 2022
1 parent 205bc41 · commit 4f8361a
Showing 2 changed files with 12 additions and 8 deletions.
src/transformers/hf_argparser.py (5 additions, 1 deletion)
@@ -92,7 +92,11 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
                     " the argument parser only supports one type per argument."
                     f" Problem encountered in field '{field.name}'."
                 )
-            if bool not in field.type.__args__:
+            if type(None) not in field.type.__args__:
+                # filter `str` in Union
+                field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
+                origin_type = getattr(field.type, "__origin__", field.type)
+            elif bool not in field.type.__args__:
                 # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
                 field.type = (
                     field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
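
To see what the new branch does, here is a minimal, self-contained sketch (not the transformers source itself) of how a `Union[Enum, str]` annotation is narrowed down to its enum member, which the parser can then use to build `choices`. The `IntervalStrategy` enum below is a local stand-in for the one in `transformers.trainer_utils`:

from enum import Enum
from typing import Union

# Local stand-in for transformers.trainer_utils.IntervalStrategy,
# defined here so the sketch runs on its own.
class IntervalStrategy(Enum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"

annotation = Union[IntervalStrategy, str]
type_args = annotation.__args__  # (IntervalStrategy, str)

# Mirrors the added branch: the Union contains no NoneType, so drop
# `str` and keep the other member (the enum type).
if type(None) not in type_args:
    narrowed = type_args[0] if type_args[1] == str else type_args[1]
    print(narrowed)  # <enum 'IntervalStrategy'>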
src/transformers/training_args.py (7 additions, 7 deletions)
@@ -20,7 +20,7 @@
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from .debug_utils import DebugOption
 from .trainer_utils import (
@@ -493,7 +493,7 @@ class TrainingArguments:
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
     do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
-    evaluation_strategy: IntervalStrategy = field(
+    evaluation_strategy: Union[IntervalStrategy, str] = field(
         default="no",
         metadata={"help": "The evaluation strategy to use."},
     )
@@ -559,7 +559,7 @@
         default=-1,
         metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
     )
-    lr_scheduler_type: SchedulerType = field(
+    lr_scheduler_type: Union[SchedulerType, str] = field(
         default="linear",
         metadata={"help": "The scheduler type to use."},
     )
@@ -596,14 +596,14 @@
         },
     )
     logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
-    logging_strategy: IntervalStrategy = field(
+    logging_strategy: Union[IntervalStrategy, str] = field(
         default="steps",
         metadata={"help": "The logging strategy to use."},
     )
     logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
     logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
-    save_strategy: IntervalStrategy = field(
+    save_strategy: Union[IntervalStrategy, str] = field(
         default="steps",
         metadata={"help": "The checkpoint save strategy to use."},
     )
@@ -815,7 +815,7 @@ class TrainingArguments:
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
-    optim: OptimizerNames = field(
+    optim: Union[OptimizerNames, str] = field(
         default="adamw_hf",
         metadata={"help": "The optimizer to use."},
     )
@@ -868,7 +868,7 @@ class TrainingArguments:
     hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
-    hub_strategy: HubStrategy = field(
+    hub_strategy: Union[HubStrategy, str] = field(
         default="every_save",
         metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
     )
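
The practical effect of the widened annotations: both the enum member and its plain string value now satisfy the type checker, while `TrainingArguments` keeps converting strings to the corresponding enum at runtime in `__post_init__`. A hedged usage sketch, assuming a transformers version from around this commit (where the parameter is still named `evaluation_strategy`):

from transformers import TrainingArguments
from transformers.trainer_utils import IntervalStrategy

# Both spellings satisfy Union[IntervalStrategy, str]; the string form
# is normalized to the enum internally, so the two objects agree.
args_from_str = TrainingArguments(output_dir="out", evaluation_strategy="epoch")
args_from_enum = TrainingArguments(output_dir="out", evaluation_strategy=IntervalStrategy.EPOCH)
assert args_from_str.evaluation_strategy == args_from_enum.evaluation_strategy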
