diff --git a/src/setfit/data.py b/src/setfit/data.py index abb6d8a2..5b2330a3 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import pandas as pd import torch @@ -107,7 +107,7 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i def create_fewshot_splits( - dataset: Dataset, sample_sizes: List[int], add_data_augmentation: bool = False, dataset_name: str = None + dataset: Dataset, sample_sizes: List[int], add_data_augmentation: bool = False, dataset_name: Optional[str] = None ) -> DatasetDict: """Creates training splits from the dataset with an equal number of samples per class (when possible).""" splits_ds = DatasetDict() diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 976b0eda..df18c92e 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -77,7 +77,6 @@ def __init__( ) -> None: super(models.Dense, self).__init__() # init on models.Dense's parent: nn.Module - self.linear = None if in_features is not None: self.linear = nn.Linear(in_features, out_features, bias=bias) else: @@ -155,8 +154,7 @@ def predict(self, x_test: Union[torch.Tensor, "ndarray"]) -> Union[torch.Tensor, def get_loss_fn(self): if self.out_features == 1: # if single target return torch.nn.BCELoss() - else: - return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss() @property def device(self) -> torch.device: @@ -167,7 +165,7 @@ def device(self) -> torch.device: """ return next(self.parameters()).device - def get_config_dict(self) -> Dict[str, Union[int, float, bool]]: + def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]: return { "in_features": self.in_features, "out_features": self.out_features, @@ -193,9 +191,9 @@ class SetFitModel(PyTorchModelHubMixin): def __init__( self, - model_body: Optional[nn.Module] = None, - model_head: Optional[Union[nn.Module, LogisticRegression]] = None, - multi_target_strategy: str = None, + model_body: Optional[SentenceTransformer] = None, + model_head: Optional[Union[SetFitHead, LogisticRegression]] = None, + multi_target_strategy: Optional[str] = None, l2_weight: float = 1e-2, normalize_embeddings: bool = False, ) -> None: @@ -212,7 +210,7 @@ def fit( self, x_train: List[str], y_train: List[int], - num_epochs: Optional[int] = None, + num_epochs: int, batch_size: Optional[int] = None, learning_rate: Optional[float] = None, body_learning_rate: Optional[float] = None, @@ -257,7 +255,7 @@ def _prepare_dataloader( self, x_train: List[str], y_train: List[int], - batch_size: int, + batch_size: Optional[int] = None, max_length: Optional[int] = None, shuffle: bool = True, ) -> DataLoader: diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 89b41987..a2f90054 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -78,10 +78,10 @@ class SetFitTrainer: def __init__( self, - model: "SetFitModel" = None, - train_dataset: "Dataset" = None, - eval_dataset: "Dataset" = None, - model_init: Callable[[], "SetFitModel"] = None, + model: Optional["SetFitModel"] = None, + train_dataset: Optional["Dataset"] = None, + eval_dataset: Optional["Dataset"] = None, + model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", loss_class=losses.CosineSimilarityLoss, num_iterations: int = 20, @@ -89,7 +89,7 @@ def __init__( learning_rate: float = 2e-5, batch_size: int = 16, seed: int = 42, - column_mapping: Dict[str, str] = None, + column_mapping: Optional[Dict[str, str]] = None, use_amp: bool = False, warmup_proportion: float = 0.1, distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance, @@ -103,6 +103,7 @@ def __init__( self.train_dataset = train_dataset self.eval_dataset = eval_dataset + self.model_init = model_init self.metric = metric self.loss_class = loss_class self.num_iterations = num_iterations @@ -119,7 +120,6 @@ def __init__( if model is None: if model_init is not None: - self.model_init = model_init model = self.call_model_init() else: raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument") @@ -127,8 +127,6 @@ def __init__( if model_init is not None: raise RuntimeError("`SetFitTrainer` requires either a `model` or `model_init` argument, but not both") - self.model_init = model_init - self.model = model self.hp_search_backend = None self._freeze = True # If True, will train the body only; otherwise, train the body and head @@ -190,7 +188,7 @@ def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = Fals setattr(self, key, value) elif number_of_arguments(self.model_init) == 0: # we do not warn if model_init could be using it logger.warning( - f"Trying to set {key} in the hyperparameter search but there is no corresponding field in " + f"Trying to set {key!r} in the hyperparameter search but there is no corresponding field in " "`SetFitTrainer`, and `model_init` does not take any arguments." ) @@ -215,17 +213,17 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): logger.info(f"Trial: {params}") self.apply_hyperparameters(params, final_model=False) - def call_model_init(self, params: Dict[str, Any] = None): + def call_model_init(self, params: Optional[Dict[str, Any]] = None): model_init_argcount = number_of_arguments(self.model_init) if model_init_argcount == 0: model = self.model_init() elif model_init_argcount == 1: model = self.model_init(params) else: - raise RuntimeError("model_init should have 0 or 1 argument.") + raise RuntimeError("`model_init` should have 0 or 1 argument.") if model is None: - raise RuntimeError("model_init should not return None.") + raise RuntimeError("`model_init` should not return None.") return model @@ -267,7 +265,7 @@ def train( body_learning_rate: Optional[float] = None, l2_weight: Optional[float] = None, max_length: Optional[int] = None, - trial: Union["optuna.Trial", Dict[str, Any]] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, show_progress_bar: bool = True, ): """ @@ -302,7 +300,7 @@ def train( self._hp_search_setup(trial) # sets trainer parameters and initializes model if self.train_dataset is None: - raise ValueError("SetFitTrainer: training requires a train_dataset.") + raise ValueError("Training requires a `train_dataset` given to the `SetFitTrainer` initialization.") self._validate_column_mapping(self.train_dataset) train_dataset = self.train_dataset @@ -516,7 +514,7 @@ def push_to_hub( organization: Optional[str] = None, private: Optional[bool] = None, api_endpoint: Optional[str] = None, - use_auth_token: Union[bool, str] = None, + use_auth_token: Optional[Union[bool, str]] = None, git_user: Optional[str] = None, git_email: Optional[str] = None, config: Optional[dict] = None, diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index 31d3a2f1..2a4d058b 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -66,10 +66,10 @@ class DistillationSetFitTrainer(SetFitTrainer): def __init__( self, teacher_model: "SetFitModel", - student_model: "SetFitModel" = None, - train_dataset: "Dataset" = None, - eval_dataset: "Dataset" = None, - model_init: Callable[[], "SetFitModel"] = None, + student_model: Optional["SetFitModel"] = None, + train_dataset: Optional["Dataset"] = None, + eval_dataset: Optional["Dataset"] = None, + model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", loss_class: torch.nn.Module = losses.CosineSimilarityLoss, num_iterations: int = 20, @@ -77,7 +77,7 @@ def __init__( learning_rate: float = 2e-5, batch_size: int = 16, seed: int = 42, - column_mapping: Dict[str, str] = None, + column_mapping: Optional[Dict[str, str]] = None, use_amp: bool = False, warmup_proportion: float = 0.1, ) -> None: @@ -108,7 +108,7 @@ def train( learning_rate: Optional[float] = None, body_learning_rate: Optional[float] = None, l2_weight: Optional[float] = None, - trial: Union["optuna.Trial", Dict[str, Any]] = None, + trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, show_progress_bar: bool = True, ): """ diff --git a/src/setfit/utils.py b/src/setfit/utils.py index 4f14ccd0..409edb05 100644 --- a/src/setfit/utils.py +++ b/src/setfit/utils.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from time import monotonic_ns -from typing import Any, Dict, List, NamedTuple, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple from datasets import Dataset, DatasetDict, load_dataset from sentence_transformers import losses @@ -101,7 +101,7 @@ class Benchmark: bench.summary() """ - out_path: str = None + out_path: Optional[str] = None summary_msg: str = field(default_factory=str) def print(self, msg: str) -> None: