diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 3ea9f494..566ae4c9 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -7,16 +7,16 @@ We love your input! We want to make contributing to this project as easy and tra
 - Proposing new features
 - Becoming a maintainer
 
-## We Develop with Github
-We use github to host code, to track issues and feature requests, as well as accept pull requests.
+## We Develop with GitHub
+We use GitHub to host code, to track issues and feature requests, as well as to accept pull requests.
 
-## We Use [Github Flow](https://guides.github.com/introduction/flow/index.html), So All Code Changes Happen Through Pull Requests
+## We Use [GitHub Flow](https://guides.github.com/introduction/flow/index.html), So All Code Changes Happen Through Pull Requests
 Pull requests are the best way to propose changes to the codebase. We actively welcome your pull requests:
 
 1. Fork the repo and create your branch from the **`dev` branch**.
 2. If you've added code that should be tested, you **must** ensure it is properly tested.
 3. If you've changed APIs, update the documentation.
-4. Ensure the Travis test suite passes.
+4. Ensure the CI/CD test suite passes.
 5. Make sure your code lints.
 6. Submit that pull request!
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 3d3d4a04..daa2de3a 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -21,6 +21,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
+          python -m pip install --upgrade pip
           pip install -Ur requirements.txt
           pip install -Ur docs/requirements.txt
           pip install -e .
diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index eab6483a..64bae30e 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -17,6 +17,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
+          python -m pip install --upgrade pip
           pip install -Ur requirements.txt
           pip install -Ur styling_requirements.txt
           pip install -Ur tests/requirements.txt
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 81648578..641707e1 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -17,6 +17,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
+          python -m pip install --upgrade pip
           pip install -Ur requirements.txt
           pip install -Ur tests/requirements.txt
           python setup.py develop
@@ -38,6 +39,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
+          python -m pip install --upgrade pip
           pip install -Ur requirements.txt
           pip install -Ur tests/requirements.txt
           python setup.py develop
diff --git a/.gitignore b/.gitignore
index 25785c8e..b7bbdfbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -144,3 +144,5 @@ deepparse/version.py
 *.ckpt
 *mlruns/
+
+*model/
\ No newline at end of file
diff --git a/.release/bpemb.version b/.release/bpemb.version
new file mode 100644
index 00000000..b31b8547
--- /dev/null
+++ b/.release/bpemb.version
@@ -0,0 +1 @@
+aa32fa918494b461202157c57734c374
diff --git a/.release/bpemb_attention.version b/.release/bpemb_attention.version
new file mode 100644
index 00000000..bcc9ea1f
--- /dev/null
+++ b/.release/bpemb_attention.version
@@ -0,0 +1 @@
+cfb190902476376573591c0ec6f91ece
diff --git a/.release/fasttext.version b/.release/fasttext.version
new file mode 100644
index 00000000..b19d26d5
--- /dev/null
+++ b/.release/fasttext.version
@@ -0,0 +1 @@
+f67a0517c70a314bdde0b8440f21139d
diff --git a/.release/fasttext_attention.version b/.release/fasttext_attention.version
new file mode 100644
index 00000000..12db9cc1
--- /dev/null
+++ b/.release/fasttext_attention.version
@@ -0,0 +1 @@
+a2b688bdfa2aa7c009bb7d980e352978
diff --git a/.release/model_version_release.md b/.release/model_version_release.md
new file mode 100644
index 00000000..46e9616e
--- /dev/null
+++ b/.release/model_version_release.md
@@ -0,0 +1,5 @@
+# How to Create a New Model's Version
+
+1. `md5sum model.ckpt > model.version`
+2. Remove the `model.ckpt` text in the `model.version` file
+3. Update the latest BPEmb and FastText hashes in `tests/test_tools.py`
\ No newline at end of file
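For illustration, the three steps above can also be scripted. The following is a minimal Python sketch (the filenames `model.ckpt` and `fasttext.version` are stand-ins for the actual checkpoint and version files; `hashlib.md5` computes the same digest as `md5sum`):

.. code-block:: python

    import hashlib

    def write_model_version(checkpoint_path: str, version_path: str) -> None:
        # Step 1: compute the MD5 checksum of the checkpoint, as `md5sum` would.
        with open(checkpoint_path, "rb") as checkpoint_file:
            md5_digest = hashlib.md5(checkpoint_file.read()).hexdigest()

        # Step 2: write only the hash, dropping the trailing filename that
        # `md5sum model.ckpt > model.version` leaves in the file.
        with open(version_path, "w", encoding="utf-8") as version_file:
            version_file.write(md5_digest)

    # Step 3 (updating the hashes in `tests/test_tools.py`) remains manual.
    write_model_version("model.ckpt", "fasttext.version")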
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 885630e8..6b9c8daa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -313,4 +313,17 @@
 suggested in the [documentation](https://pytorch.org/tutorials//intermediate/torch_compile_tutorial.html).
 It increases the performance by about 1/100.
 
+## 0.9.7
+
+- New models release with more metadata.
+- Add a feature to use an AddressParser from a URI.
+- Add a feature to upload the trained model to a URI.
+- Add an example of how to load an AddressParser from a URI and how to upload a retrained model to a URI.
+- Improve error handling of `path_to_retrained_model`.
+- Bug-fix a pre-processor error.
+- Add a verbose override and improve verbosity handling in retrain.
+- Bug-fix the broken FastText installation by using `fasttext-wheel` instead of `fasttext`
+  (see [here](https://github.com/facebookresearch/fastText/issues/512#issuecomment-1534519551)
+  and [here](https://github.com/facebookresearch/fastText/pull/1292)).
+
 ## dev
diff --git a/deepparse/__init__.py b/deepparse/__init__.py
index 462804e9..dd64adac 100644
--- a/deepparse/__init__.py
+++ b/deepparse/__init__.py
@@ -2,4 +2,4 @@
 from .fasttext_tools import *
 from .tools import *
 from .version import __version__
-from .weights_init import *
+from .weights_tools import *
diff --git a/deepparse/cli/parse.py b/deepparse/cli/parse.py
index 39fc54f8..a075cf0a 100644
--- a/deepparse/cli/parse.py
+++ b/deepparse/cli/parse.py
@@ -50,7 +50,7 @@ def main(args=None) -> None:
 
     .. code-block:: sh
 
-        parse fasttext ./dataset.csv parsed_address.pckl --path_to_model_weights ./path
+        parse fasttext ./dataset.csv parsed_address.pckl --path_to_retrained_model ./path
 
     """
     if args is None:  # pragma: no cover
diff --git a/deepparse/cli/parser_arguments_adder.py b/deepparse/cli/parser_arguments_adder.py
index c50426f0..a398a9c8 100644
--- a/deepparse/cli/parser_arguments_adder.py
+++ b/deepparse/cli/parser_arguments_adder.py
@@ -108,7 +108,7 @@ def add_batch_size_arg(parser: ArgumentParser) -> None:
 def add_path_to_retrained_model_arg(parser: ArgumentParser) -> None:
     parser.add_argument(
         "--path_to_retrained_model",
-        help=wrap("A path to a retrained model to use for testing."),
+        help=wrap("A path to a retrained model to use. It can be an S3 URI."),
         type=str,
         default=None,
     )
diff --git a/deepparse/network/decoder.py b/deepparse/network/decoder.py
index 8d578069..f0c347ab 100644
--- a/deepparse/network/decoder.py
+++ b/deepparse/network/decoder.py
@@ -6,7 +6,7 @@
 import torch
 from torch import nn
 
-from ..weights_init import weights_init
+from .. import weights_init
 
 
 class Decoder(nn.Module):
diff --git a/deepparse/network/encoder.py b/deepparse/network/encoder.py
index 27d911f6..5fafb917 100644
--- a/deepparse/network/encoder.py
+++ b/deepparse/network/encoder.py
@@ -7,7 +7,7 @@
 from torch import nn
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
-from ..weights_init import weights_init
+from .. import weights_init
 
 
 class Encoder(nn.Module):
diff --git a/deepparse/network/seq2seq.py b/deepparse/network/seq2seq.py
index 38fc5b5a..5e05625b 100644
--- a/deepparse/network/seq2seq.py
+++ b/deepparse/network/seq2seq.py
@@ -4,7 +4,6 @@
 import random
 import warnings
 from abc import ABC
-from collections import OrderedDict
 from typing import Tuple, Union, List
 
 import torch
@@ -12,6 +11,7 @@
 from .decoder import Decoder
 from .encoder import Encoder
+from .. import handle_weights_upload
 from ..tools import download_weights, latest_version
@@ -113,20 +113,21 @@ def _load_pre_trained_weights(self, model_type: str, cache_dir: str, offline: bo
             )
             download_weights(model_type, cache_dir, verbose=self.verbose)
 
-        all_layers_params = torch.load(model_path, map_location=self.device)
-        self.load_state_dict(all_layers_params)
+        self._load_weights(path_to_model_torch_archive=model_path)
 
-    def _load_weights(self, path_to_retrained_model: str) -> None:
+    def _load_weights(self, path_to_model_torch_archive: str) -> None:
         """
         Method to load (into the network) the weights.
 
         Args:
-            path_to_retrained_model (str): The path to the fine-tuned model.
+            path_to_model_torch_archive (str): The path to the fine-tuned model Torch archive.
         """
-        all_layers_params = torch.load(path_to_retrained_model, map_location=self.device)
-        if isinstance(all_layers_params, dict) and not isinstance(all_layers_params, OrderedDict):
-            # Case where we have a retrained model with a different tagging space
-            all_layers_params = all_layers_params.get("address_tagger_model")
+        all_layers_params = handle_weights_upload(
+            path_to_model_to_upload=path_to_model_torch_archive, device=self.device
+        )
+
+        # Our Torch archives always include metadata along with the model weights.
+        all_layers_params = all_layers_params.get("address_tagger_model")
         self.load_state_dict(all_layers_params)
 
     def _encoder_step(self, to_predict: torch.Tensor, lengths: List, batch_size: int) -> Tuple:
diff --git a/deepparse/parser/address_parser.py b/deepparse/parser/address_parser.py
index 18718ae5..f824c134 100644
--- a/deepparse/parser/address_parser.py
+++ b/deepparse/parser/address_parser.py
@@ -14,6 +14,7 @@
 from typing import Dict, List, Tuple, Union, Callable
 
 import torch
+from cloudpathlib import CloudPath, S3Path
 from poutyne.framework import Experiment
 from torch.optim import SGD
 from torch.utils.data import DataLoader, Subset
@@ -43,6 +44,7 @@
 from ..pre_processing import trailing_whitespace_cleaning, double_whitespaces_cleaning
 from ..tools import CACHE_PATH, valid_poutyne_version
 from ..vectorizer import VectorizerFactory
+from ..weights_tools import handle_weights_upload
 
 _pre_trained_tags_to_idx = {
     "StreetNumber": 0,
@@ -86,7 +88,7 @@ class AddressParser:
         - ``"lightest"`` (the one using the less RAM and GPU usage) (equivalent to ``"fasttext-light"``),
         - ``"best"`` (the best accuracy performance) (equivalent to ``"bpemb"``).
 
         The default value is ``"best"`` for the most accurate model. Ignored if ``path_to_retrained_model`` is not
         ``None``. To further improve performance, consider using the models (fasttext or BPEmb) with their
         counterparts using an attention mechanism with the ``attention_mechanism`` flag.
         attention_mechanism (bool): Whether to use the model with an attention mechanism. The model will use an
@@ -102,10 +104,13 @@ class AddressParser:
             The default value is GPU with the index ``0`` if it exists. Otherwise, the value is ``CPU``.
         rounding (int): The rounding to use when asking the probability of the tags. The default value is four digits.
         verbose (bool): Turn on/off the verbosity of the model weights download and loading. The default value is True.
-        path_to_retrained_model (Union[str, None]): The path to the retrained model to use for prediction. We will
-            infer the ``model_type`` of the retrained model. The default value is ``None``, meaning we use our
+        path_to_retrained_model (Union[S3Path, str, None]): The path to the retrained model to use for prediction.
+            We will infer the ``model_type`` of the retrained model. The default value is ``None``, meaning we use our
             pretrained model. If the retrained model uses an attention mechanism, ``attention_mechanism`` needs to
-            be set to True.
+            be set to True. The ``path_to_retrained_model`` can also be an S3-like (Azure, AWS, Google) bucket URI
+            string path (e.g. ``"s3://path/to/aws/s3/bucket.ckpt"``), or an ``S3Path`` S3-like URI built with
+            `cloudpathlib`. See the `cloudpathlib` documentation for details on the supported S3 bucket providers
+            and URI conventions. The default value is ``None``.
         cache_dir (Union[str, None]): The path to the cached directory to use for downloading (and loading) the
             embeddings model and the model pretrained weights.
         offline (bool): Whether or not the model is an offline one, meaning you have already downloaded the pre-trained
@@ -117,7 +122,7 @@ class AddressParser:
     Note:
         For both networks, we will download the pretrained weights and embeddings in the ``.cache`` directory
         for the root user. The pretrained weights take at most 44 MB. The fastText embeddings take 6.8 GB,
-        the fastText-light embeddings take 3.3 GO and bpemb take 116 MB (in .cache/bpemb).
+        the fastText-light embeddings take 3.3 GB, and BPEmb takes 116 MB (in ``".cache/bpemb"``).
 
         Also, one can download all the dependencies of our pretrained model using our CLI
         (e.g. download_model fasttext) before sending it to a node without access to Internet.
@@ -164,15 +169,15 @@ class AddressParser:
 
         .. code-block:: python
 
             address_parser = AddressParser(model_type="fasttext",
-                                           path_to_retrained_model="/path_to_a_retrain_fasttext_model")
+                                           path_to_retrained_model="/path_to_a_retrain_fasttext_model.ckpt")
 
             parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
 
         Using a retrained model trained on different tags
 
         .. code-block:: python
 
             # We don't give the model_type since it's ignored when using path_to_retrained_model
-            address_parser = AddressParser(path_to_retrained_model="/path_to_a_retrain_fasttext_model")
+            address_parser = AddressParser(path_to_retrained_model="/path_to_a_retrain_fasttext_model.ckpt")
 
             parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
 
         Using a retrained model with attention
         .. code-block:: python
 
             address_parser = AddressParser(model_type="fasttext",
-                                           path_to_retrained_model="/path_to_a_retrain_fasttext_attention_model",
+                                           path_to_retrained_model="/path_to_a_retrain_fasttext_attention_model.ckpt",
                                            attention_mechanism=True)
 
             parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
 
@@ -193,6 +198,21 @@ class AddressParser:
                                            offline=True)
 
             parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
 
+        Using a retrained model in an S3-like bucket.
+
+        .. code-block:: python
+
+            address_parser = AddressParser(model_type="fasttext",
+                                           path_to_retrained_model="s3://path/to/bucket.ckpt")
+
+            parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
+
+        Using a retrained model in an S3-like bucket using CloudPathLib.
+
+        .. code-block:: python
+
+            address_parser = AddressParser(model_type="fasttext",
+                                           path_to_retrained_model=CloudPath("s3://path/to/bucket.ckpt"))
+
+            parse_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
     """
 
     def __init__(
@@ -202,7 +222,7 @@ def __init__(
         device: Union[int, str, torch.device] = 0,
         rounding: int = 4,
         verbose: bool = True,
-        path_to_retrained_model: Union[str, None] = None,
+        path_to_retrained_model: Union[S3Path, str, None] = None,
         cache_dir: Union[str, None] = None,
         offline: bool = False,
     ) -> None:
@@ -222,17 +242,21 @@ def __init__(
         seq2seq_kwargs = {}  # Empty for default settings
 
         if path_to_retrained_model is not None:
-            checkpoint_weights = torch.load(path_to_retrained_model, map_location="cpu")
+            checkpoint_weights = handle_weights_upload(path_to_model_to_upload=path_to_retrained_model)
             if checkpoint_weights.get("model_type") is None:
                 # Validate if we have the proper metadata, it has at least the parser model type
                 # if no other thing have been modified.
-                raise RuntimeError(
-                    "You are not using the proper retrained checkpoint. "
+                error_text = (
+                    "You are not using the proper retrained checkpoint for Deepparse, since we also export other "
+                    "information along with the model weights. "
                     "When we retrain an AddressParser, by default, we create a "
-                    "checkpoint name 'retrained_modeltype_address_parser.ckpt'. Be sure to use that"
-                    "checkpoint since it includes some metadata for the reloading."
+                    "checkpoint named 'retrained_modeltype_address_parser.ckpt', "
+                    "where 'modeltype' is the AddressParser model type (e.g. 'fasttext', 'bpemb'). "
+                    "The checkpoint name can also change if you give the retrained model a name. "
+                    "Be sure to use that checkpoint, since it includes some metadata for the reloading. "
+                    "See AddressParser.retrain for more details."
                 )
+                raise RuntimeError(error_text)
             if validate_if_new_seq2seq_params(checkpoint_weights):
                 seq2seq_kwargs = checkpoint_weights.get("seq2seq_params")
             if validate_if_new_prediction_tags(checkpoint_weights):
@@ -453,6 +477,7 @@ def retrain(
         seq2seq_params: Union[Dict, None] = None,
         layers_to_freeze: Union[str, None] = None,
         name_of_the_retrain_parser: Union[None, str] = None,
+        verbose: Union[None, bool] = None,
     ) -> List[Dict]:
         # pylint: disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
@@ -497,6 +522,12 @@ def retrain(
             logging_path (str): The logging path for the checkpoints. Poutyne will use the best one and reload
                 the state if any checkpoints are there. Thus, an error will be raised if you change the model type.
                 For example, you retrain a FastText model and then retrain a BPEmb in the same logging path directory.
+                The logging_path can also be an S3-like (Azure, AWS, Google) bucket URI string path
+                (e.g. ``"s3://path/to/aws/s3/bucket.ckpt"``), or an ``S3Path`` S3-like URI built with `cloudpathlib`.
+                See the `cloudpathlib` documentation for details on the supported S3 bucket providers and URI
+                conventions.
+                If the logging_path is an S3 bucket, we will only save the best checkpoint to the S3 bucket at the
+                end of training.
                 By default, the path is ``./checkpoints``.
             disable_tensorboard (bool): To disable Poutyne automatic Tensorboard monitoring. By default, we
                 disable them (true).
@@ -542,6 +573,9 @@ def retrain(
                 - if prediction_tags is not ``None``, the following tag: ``ModifiedPredictionTags``,
                 - if seq2seq_params is not ``None``, the following tag: ``ModifiedSeq2SeqConfiguration``, and
                 - if layers_to_freeze is not ``None``, the following tag: ``FreezedLayer{portion}``.
+            verbose (Union[None, bool]): To override the AddressParser verbosity during the retraining. When set to
+                ``True`` or ``False``, it overrides the verbosity for the retraining only (it does not change the
+                AddressParser verbosity). If set to the default value ``None``, the AddressParser verbosity is used.
 
         Return:
@@ -742,6 +776,10 @@ def retrain(
             batch_metrics=[accuracy],
         )
 
+        # Handle the verbose override parameter
+        if verbose is None:
+            verbose = self.verbose
+
         try:
             with_capturing_context = False
             if not valid_poutyne_version(min_major=1, min_minor=8):
@@ -760,6 +798,7 @@ def retrain(
                 callbacks=callbacks,
                 disable_tensorboard=disable_tensorboard,
                 capturing_context=with_capturing_context,
+                verbose=verbose,
             )
         except RuntimeError as error:
             list_of_file_path = os.listdir(path=".")
@@ -797,6 +836,7 @@ def retrain(
             else f"retrained_{self.model_type}_address_parser.ckpt"
         )
         file_path = os.path.join(logging_path, file_name)
+
         torch_save = {
             "address_tagger_model": exp.model.network.state_dict(),
             "model_type": self.model_type,
@@ -817,7 +857,29 @@ def retrain(
             }
         )
 
-        torch.save(torch_save, file_path)
+        if isinstance(file_path, S3Path):
+            # To handle a cloudpathlib S3Path logging path
+            try:
+                with file_path.open("wb") as file:
+                    torch.save(torch_save, file)
+            except FileNotFoundError as error:
+                raise FileNotFoundError("The file in the S3 bucket was not found.") from error
+
+        elif "s3://" in file_path:
+            file_path = CloudPath(file_path)
+            try:
+                with file_path.open("wb") as file:
+                    torch.save(torch_save, file)
+            except FileNotFoundError as error:
+                raise FileNotFoundError("The file in the S3 bucket was not found.") from error
+        else:
+            try:
+                torch.save(torch_save, file_path)
+            except FileNotFoundError as error:
+                if "s3" in file_path or "//" in file_path or ":" in file_path:
+                    raise FileNotFoundError(
+                        "Are you trying to use an AWS S3 URI? If so, the path needs to start with 's3://'."
+                    ) from error
+                raise
         return train_res
 
     def test(
@@ -1114,8 +1176,8 @@ def _predict_pipeline(self, data: List) -> Tuple:
         """
         return self.processor.process_for_inference(data)
 
-    @staticmethod
     def _retrain(
+        self,
         experiment: Experiment,
         train_generator: DatasetContainer,
         valid_generator: DatasetContainer,
@@ -1124,6 +1186,7 @@ def _retrain(
         callbacks: List,
         disable_tensorboard: bool,
         capturing_context: bool,
+        verbose: Union[None, bool],
     ) -> List[Dict]:
         # pylint: disable=too-many-arguments
         # If Poutyne 1.7 and before, we capture poutyne print since it prints some exception.
@@ -1136,6 +1199,7 @@ def _retrain(
             seed=seed,
             callbacks=callbacks,
             disable_tensorboard=disable_tensorboard,
+            verbose=verbose,
         )
         return train_res
 
@@ -1250,9 +1314,12 @@ def _apply_pre_processors(self, addresses: List[str]) -> List[str]:
         res = []
 
         for address in addresses:
+            processed_address = address
+
             for pre_processor in self.pre_processors:
-                processed_address = pre_processor(address)
-            res.append(" ".join(processed_address.split()))
+                processed_address = pre_processor(processed_address)
+
+            res.append(" ".join(processed_address.split()))
         return res
 
     def is_same_model_type(self, other) -> bool:
diff --git a/deepparse/parser/tools.py b/deepparse/parser/tools.py
index 8e77afe9..e4a3b495 100644
--- a/deepparse/parser/tools.py
+++ b/deepparse/parser/tools.py
@@ -1,7 +1,7 @@
-import math
 import os
 from typing import List, OrderedDict, Tuple
 
+import math
 import numpy as np
 import torch
@@ -134,7 +134,10 @@ def infer_model_type(checkpoint_weights: OrderedDict, attention_mechanism: bool)
     else:
         model_type = "fasttext"
 
-    if "decoder.linear_attention_mechanism_encoder_outputs.weight" in checkpoint_weights.keys():
+    if (
+        "decoder.linear_attention_mechanism_encoder_outputs.weight"
+        in checkpoint_weights.get("address_tagger_model").keys()
+    ):
         attention_mechanism = True
 
     return model_type, attention_mechanism
diff --git a/deepparse/weights_init.py b/deepparse/weights_init.py
deleted file mode 100644
index 5e6b13b2..00000000
--- a/deepparse/weights_init.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from torch import nn
-from torch.nn import init
-
-
-def weights_init(m: nn.Module) -> None:
-    """
-    Function to initialize the weights of a model layers.
-
-    Usage:
-        network = Model()
-        network.apply(weight_init)
-    """
-    if isinstance(m, nn.Linear):
-        init.xavier_normal_(m.weight.data)
-        init.normal_(m.bias.data)
-    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
-        for param in m.parameters():
-            if len(param.shape) >= 2:
-                init.orthogonal_(param.data)
-            else:
-                init.normal_(param.data)
diff --git a/deepparse/weights_tools.py b/deepparse/weights_tools.py
new file mode 100644
index 00000000..dd5831bd
--- /dev/null
+++ b/deepparse/weights_tools.py
@@ -0,0 +1,56 @@
+from typing import OrderedDict, Union
+
+import torch
+from cloudpathlib import CloudPath, S3Path
+from torch import nn
+from torch.nn import init
+
+
+def weights_init(m: nn.Module) -> None:
+    """
+    Function to initialize the weights of a model's layers.
+
+    Usage:
+        network = Model()
+        network.apply(weights_init)
+    """
+    if isinstance(m, nn.Linear):
+        init.xavier_normal_(m.weight.data)
+        init.normal_(m.bias.data)
+    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
+        for param in m.parameters():
+            if len(param.shape) >= 2:
+                init.orthogonal_(param.data)
+            else:
+                init.normal_(param.data)
+
+
+def handle_weights_upload(
+    path_to_model_to_upload: Union[str, S3Path], device: Union[str, torch.device] = "cpu"
+) -> OrderedDict:
+    if isinstance(path_to_model_to_upload, S3Path):
+        # To handle a cloudpathlib S3Path
+        try:
+            with path_to_model_to_upload.open("rb") as file:
+                checkpoint_weights = torch.load(file, map_location=device)
+        except FileNotFoundError as error:
+            raise FileNotFoundError("The file in the S3 bucket was not found.") from error
+    elif "s3://" in path_to_model_to_upload:
+        # To handle a str S3-like URI.
+        path_to_model_to_upload = CloudPath(path_to_model_to_upload)
+        try:
+            with path_to_model_to_upload.open("rb") as file:
+                checkpoint_weights = torch.load(file, map_location=device)
+        except FileNotFoundError as error:
+            raise FileNotFoundError("The file in the S3 bucket was not found.") from error
+    else:
+        # The path is a local one (or a wrongly written S3 URI).
+        try:
+            checkpoint_weights = torch.load(path_to_model_to_upload, map_location=device)
+        except FileNotFoundError as error:
+            if "s3" in path_to_model_to_upload or "//" in path_to_model_to_upload or ":" in path_to_model_to_upload:
+                raise FileNotFoundError(
+                    "Are you trying to use an AWS S3 URI? If so, the path needs to start with 's3://'."
+                ) from error
+            raise FileNotFoundError(f"The file {path_to_model_to_upload} was not found.") from error
+    return checkpoint_weights
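For clarity, here is a short usage sketch of the new ``handle_weights_upload`` helper (the bucket name and checkpoint filename are placeholders; the ``address_tagger_model`` and ``model_type`` keys come from the checkpoint dictionary that ``retrain`` saves, as shown above):

.. code-block:: python

    from cloudpathlib import CloudPath

    from deepparse.weights_tools import handle_weights_upload

    # A local checkpoint path.
    checkpoint = handle_weights_upload("retrained_fasttext_address_parser.ckpt")

    # An S3-like URI string, or an explicit cloudpathlib path object.
    checkpoint = handle_weights_upload("s3://a_bucket/retrained_fasttext_address_parser.ckpt")
    checkpoint = handle_weights_upload(CloudPath("s3://a_bucket/retrained_fasttext_address_parser.ckpt"))

    # The returned value is the saved checkpoint dictionary, weights plus metadata.
    state_dict = checkpoint.get("address_tagger_model")
    model_type = checkpoint.get("model_type")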
diff --git a/docs/source/examples/fine_tuning_uri.rst b/docs/source/examples/fine_tuning_uri.rst
new file mode 100644
index 00000000..c734f7ab
--- /dev/null
+++ b/docs/source/examples/fine_tuning_uri.rst
@@ -0,0 +1,63 @@
+.. role:: hidden
+    :class: hidden-section
+
+Retrain a Pretrained Model
+**************************
+
+.. code-block:: python
+
+    import os
+
+    import poutyne
+
+    from deepparse import download_from_public_repository
+    from deepparse.dataset_container import PickleDatasetContainer
+    from deepparse.parser import AddressParser
+
+
+First, let's download the train and test data from the public repository.
+
+.. code-block:: python
+
+    saving_dir = "./data"
+    file_extension = "p"
+    training_dataset_name = "sample_incomplete_data"
+    test_dataset_name = "test_sample_data"
+    download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
+    download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)
+
+Now, let's create a training and a test container.
+
+.. code-block:: python
+
+    training_container = PickleDatasetContainer(os.path.join(saving_dir,
+                                                             training_dataset_name + "." + file_extension))
+    test_container = PickleDatasetContainer(os.path.join(saving_dir,
+                                                         test_dataset_name + "." + file_extension))
+
+We will retrain the ``FastText`` version of our pretrained model.
+
+.. code-block:: python
+
+    path_to_your_uri = "s3://<your_bucket_path>/fasttext.ckpt"
+    address_parser = AddressParser(model_type="fasttext", device=0, path_to_retrained_model=path_to_your_uri)
+
+Now, let's retrain for ``5`` epochs using a batch size of ``8``, since the dataset is really small for this example.
+Let's start with the default learning rate of ``0.01`` and use a learning rate scheduler to lower the learning rate
+as we progress.
+
+.. code-block:: python
+
+    # Reduce LR by a factor of 10 each epoch
+    lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)
+
+The retrained model best checkpoint (ckpt) will be saved in the S3 Bucket
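A retrain call completing this example might look like the following sketch (the bucket path is a placeholder, and the ``train_ratio`` and ``num_workers`` values are illustrative; the epochs, batch size, and scheduler follow the text above):

.. code-block:: python

    address_parser.retrain(
        training_container,
        train_ratio=0.8,
        epochs=5,
        batch_size=8,
        num_workers=2,
        callbacks=[lr_scheduler],
        logging_path="s3://<your_bucket_path>/checkpoints",
    )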