From 4470a409a95a7c9d0d407d57744d5b1b86c5078c Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 16 Nov 2022 12:55:12 +0100 Subject: [PATCH] Various cleanups; type hint fixes incl. corresponding to PEP 484 See towards the bottom of https://peps.python.org/pep-0484/#union-types. It states that in past versions of the PEP, int = None would have been allowed, as it implicitly becomes Optional[int] = None. This is no longer the case, and we need to be explicit. --- src/setfit/data.py | 4 ++-- src/setfit/modeling.py | 16 +++++++--------- src/setfit/trainer.py | 28 +++++++++++++--------------- src/setfit/trainer_distillation.py | 12 ++++++------ src/setfit/utils.py | 4 ++-- 5 files changed, 30 insertions(+), 34 deletions(-) diff --git a/src/setfit/data.py b/src/setfit/data.py index abb6d8a2..5b2330a3 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import pandas as pd import torch @@ -107,7 +107,7 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i def create_fewshot_splits( - dataset: Dataset, sample_sizes: List[int], add_data_augmentation: bool = False, dataset_name: str = None + dataset: Dataset, sample_sizes: List[int], add_data_augmentation: bool = False, dataset_name: Optional[str] = None ) -> DatasetDict: """Creates training splits from the dataset with an equal number of samples per class (when possible).""" splits_ds = DatasetDict() diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 6d28fa02..d2d643b3 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -79,7 +79,6 @@ def __init__( ) -> None: super(models.Dense, self).__init__() # init on models.Dense's parent: nn.Module - self.linear = None if in_features is not None: self.linear = nn.Linear(in_features, out_features, bias=bias) else: @@ -157,10 +156,9 @@ def predict(self, x_test: Union[torch.Tensor, "ndarray"]) -> Union[torch.Tensor, def get_loss_fn(self): if self.out_features == 1: # if single target return torch.nn.BCELoss() - else: - return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss() - def get_config_dict(self) -> Dict[str, Union[int, float, bool]]: + def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]: return { "in_features": self.in_features, "out_features": self.out_features, @@ -186,9 +184,9 @@ class SetFitModel(PyTorchModelHubMixin): def __init__( self, - model_body: Optional[nn.Module] = None, - model_head: Optional[Union[nn.Module, LogisticRegression]] = None, - multi_target_strategy: str = None, + model_body: Optional[SentenceTransformer] = None, + model_head: Optional[Union[SetFitHead, LogisticRegression]] = None, + multi_target_strategy: Optional[str] = None, l2_weight: float = 1e-2, ) -> None: super(SetFitModel, self).__init__() @@ -204,7 +202,7 @@ def fit( self, x_train: List[str], y_train: List[int], - num_epochs: Optional[int] = None, + num_epochs: int, batch_size: Optional[int] = None, learning_rate: Optional[float] = None, body_learning_rate: Optional[float] = None, @@ -243,7 +241,7 @@ def fit( self.model_head.fit(embeddings, y_train) def _prepare_dataloader( - self, x_train: List[str], y_train: List[int], batch_size: int, shuffle: bool = True + self, x_train: List[str], y_train: List[int], batch_size: Optional[int] = None, shuffle: bool = True ) -> DataLoader: dataset = SetFitDataset( x_train, diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 5e96cd60..344d0b1e 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -73,10 +73,10 @@ class SetFitTrainer: def __init__( self, - model: "SetFitModel" = None, - train_dataset: "Dataset" = None, - eval_dataset: "Dataset" = None, - model_init: Callable[[], "SetFitModel"] = None, + model: Optional["SetFitModel"] = None, + train_dataset: Optional["Dataset"] = None, + eval_dataset: Optional["Dataset"] = None, + model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", loss_class=losses.CosineSimilarityLoss, num_iterations: int = 20, @@ -84,7 +84,7 @@ def __init__( learning_rate: float = 2e-5, batch_size: int = 16, seed: int = 42, - column_mapping: Dict[str, str] = None, + column_mapping: Optional[Dict[str, str]] = None, use_amp: bool = False, warmup_proportion: float = 0.1, distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance, @@ -97,6 +97,7 @@ def __init__( self.train_dataset = train_dataset self.eval_dataset = eval_dataset + self.model_init = model_init self.metric = metric self.loss_class = loss_class self.num_iterations = num_iterations @@ -112,7 +113,6 @@ def __init__( if model is None: if model_init is not None: - self.model_init = model_init model = self.call_model_init() else: raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument") @@ -120,8 +120,6 @@ def __init__( if model_init is not None: raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument, but not both") - self.model_init = model_init - self.model = model self.hp_search_backend = None self._freeze = True # If True, will train the body only; otherwise, train the body and head @@ -183,7 +181,7 @@ def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = Fals setattr(self, key, value) elif number_of_arguments(self.model_init) == 0: # we do not warn if model_init could be using it logger.warning( - f"Trying to set {key} in the hyperparameter search but there is no corresponding field in " + f"Trying to set {key!r} in the hyperparameter search but there is no corresponding field in " "`SetFitTrainer`, and `model_init` does not take any arguments." ) @@ -208,17 +206,17 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): logger.info(f"Trial: {params}") self.apply_hyperparameters(params, final_model=False) - def call_model_init(self, params: Dict[str, Any] = None): + def call_model_init(self, params: Optional[Dict[str, Any]] = None): model_init_argcount = number_of_arguments(self.model_init) if model_init_argcount == 0: model = self.model_init() elif model_init_argcount == 1: model = self.model_init(params) else: - raise RuntimeError("model_init should have 0 or 1 argument.") + raise RuntimeError("`model_init` should have 0 or 1 argument.") if model is None: - raise RuntimeError("model_init should not return None.") + raise RuntimeError("`model_init` should not return None.") return model @@ -259,7 +257,7 @@ def train( learning_rate: Optional[float] = None, body_learning_rate: Optional[float] = None, l2_weight: Optional[float] = None, - trial: Union["optuna.Trial", Dict[str, Any]] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, ): """ Main training entry point. @@ -287,7 +285,7 @@ def train( self._hp_search_setup(trial) # sets trainer parameters and initializes model if self.train_dataset is None: - raise ValueError("SetFitTrainer: training requires a train_dataset.") + raise ValueError("Training requires a `train_dataset` given to the `SetFitTrainer` initialization.") self._validate_column_mapping(self.train_dataset) train_dataset = self.train_dataset @@ -501,7 +499,7 @@ def push_to_hub( organization: Optional[str] = None, private: Optional[bool] = None, api_endpoint: Optional[str] = None, - use_auth_token: Union[bool, str] = None, + use_auth_token: Optional[Union[bool, str]] = None, git_user: Optional[str] = None, git_email: Optional[str] = None, config: Optional[dict] = None, diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index bd45e1cb..704eb770 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -66,10 +66,10 @@ class DistillationSetFitTrainer(SetFitTrainer): def __init__( self, teacher_model: "SetFitModel", - student_model: "SetFitModel" = None, - train_dataset: "Dataset" = None, - eval_dataset: "Dataset" = None, - model_init: Callable[[], "SetFitModel"] = None, + student_model: Optional["SetFitModel"] = None, + train_dataset: Optional["Dataset"] = None, + eval_dataset: Optional["Dataset"] = None, + model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", loss_class: torch.nn.Module = losses.CosineSimilarityLoss, num_iterations: int = 20, @@ -77,7 +77,7 @@ def __init__( learning_rate: float = 2e-5, batch_size: int = 16, seed: int = 42, - column_mapping: Dict[str, str] = None, + column_mapping: Optional[Dict[str, str]] = None, use_amp: bool = False, warmup_proportion: float = 0.1, ) -> None: @@ -108,7 +108,7 @@ def train( learning_rate: Optional[float] = None, body_learning_rate: Optional[float] = None, l2_weight: Optional[float] = None, - trial: Union["optuna.Trial", Dict[str, Any]] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, ): """ Main training entry point. diff --git a/src/setfit/utils.py b/src/setfit/utils.py index 4f14ccd0..409edb05 100644 --- a/src/setfit/utils.py +++ b/src/setfit/utils.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from time import monotonic_ns -from typing import Any, Dict, List, NamedTuple, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple from datasets import Dataset, DatasetDict, load_dataset from sentence_transformers import losses @@ -101,7 +101,7 @@ class Benchmark: bench.summary() """ - out_path: str = None + out_path: Optional[str] = None summary_msg: str = field(default_factory=str) def print(self, msg: str) -> None: