Sacrebleu: Add more tokenizers for SacreBLEU metric #2068

Merged
merged 30 commits into master from feature/sacrebleu-tokenizers on Oct 9, 2023

Changes from all commits (30 commits)
b57e2ff
sacrebleu[feat]: Add ja-mecab tokenizer
stancld Sep 10, 2023
520e302
sacrebleu[feat]: Add ko-mecab tokenizer
stancld Sep 10, 2023
8e404f4
sacrebleu[feat]: Add flores101 and flores200 tokenizers
stancld Sep 10, 2023
189184f
Update changelog
stancld Sep 10, 2023
1ab9228
requirements[test/text]: Bump up minimal sacrebleu version
stancld Sep 23, 2023
69b625b
fix: Ignore two typing issues mypy fails to derive
stancld Sep 23, 2023
ed0fc4c
infra: Try MacOS-12
stancld Sep 24, 2023
ad78c40
Merge branch 'master' into feature/sacrebleu-tokenizers
SkafteNicki Sep 24, 2023
3810a11
changelog
SkafteNicki Sep 24, 2023
9f1678d
Try to fix docs
stancld Sep 24, 2023
d35a0df
Merge branch 'master' into feature/sacrebleu-tokenizers
SkafteNicki Sep 28, 2023
f076898
text formatting
SkafteNicki Sep 28, 2023
45cb151
change name to not conflict with hf
SkafteNicki Sep 28, 2023
60c28c1
Update requirements/text.txt
SkafteNicki Sep 28, 2023
a3bcbd3
Test equivalence of AVAILABLE_TOKENIZERS and corresponding type annot…
stancld Oct 1, 2023
d22f247
Update definition of _FLORES_LOCAL_DIR
stancld Oct 1, 2023
716a719
Merge branch 'master' into feature/sacrebleu-tokenizers
stancld Oct 1, 2023
7db58fd
Merge branch 'master' into feature/sacrebleu-tokenizers
Borda Oct 2, 2023
cf4bca1
Merge branch 'master' into feature/sacrebleu-tokenizers
Borda Oct 2, 2023
a02f60e
Merge branch 'master' into feature/sacrebleu-tokenizers
Borda Oct 3, 2023
b2d9411
Merge branch 'master' into feature/sacrebleu-tokenizers
Borda Oct 4, 2023
1fb5ad7
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 4, 2023
78f383f
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 6, 2023
f9e2b1b
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 6, 2023
d7754ab
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 6, 2023
227d615
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 6, 2023
4f5af8b
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 8, 2023
e286eaf
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 9, 2023
2b29c91
skip on windows
SkafteNicki Oct 9, 2023
47b5030
Merge branch 'master' into feature/sacrebleu-tokenizers
mergify[bot] Oct 9, 2023
3 changes: 1 addition & 2 deletions .github/workflows/ci-integrate.yml
@@ -26,15 +26,14 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-20.04, macOS-11, windows-2022]
os: [ubuntu-20.04, macOS-12, windows-2022]
python-version: ["3.8", "3.10"]
requires: ["oldest", "latest"]
exclude:
- { python-version: "3.10", requires: "oldest" }
- { python-version: "3.10", os: "windows" } # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192
include:
- { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" }
- { python-version: "3.10", requires: "latest", os: "macOS-12" }
env:
PYTORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
8 changes: 4 additions & 4 deletions .github/workflows/ci-tests.yml
@@ -38,10 +38,10 @@ jobs:
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "1.13.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.0" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.0.0" }
- { os: "macOS-11", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-11", python-version: "3.9", pytorch-version: "1.13.1" }
- { os: "macOS-11", python-version: "3.10", pytorch-version: "2.0.0" }
- { os: "macOS-11", python-version: "3.11", pytorch-version: "2.0.0" }
- { os: "macOS-12", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-12", python-version: "3.9", pytorch-version: "1.13.1" }
- { os: "macOS-12", python-version: "3.10", pytorch-version: "2.0.0" }
- { os: "macOS-12", python-version: "3.11", pytorch-version: "2.0.0" }
- { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.9", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.0" }
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -6,11 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**


## [UnReleased] - 2023-MM-DD

### Added

- Added more tokenizers for `SacreBLEU` metric ([#2068](https://github.com/Lightning-AI/torchmetrics/pull/2068))


- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084))

- Added error if `NoTrainInceptionV3` is being initialized without `torch-fidelity` being installed ([#2143](https://github.com/Lightning-AI/torchmetrics/pull/2143))
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -164,4 +164,6 @@
.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score
.. _FLORES-101: https://arxiv.org/abs/2106.03193
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
5 changes: 5 additions & 0 deletions requirements/text.txt
@@ -5,3 +5,8 @@ nltk >=3.6, <=3.8.1
tqdm >=4.41.0, <=4.66.1
regex >=2021.9.24, <=2023.8.8
transformers >4.4.0, <4.30.3
mecab-python3 >=1.0.6, <1.1.0
mecab-ko >=1.0.0, <1.1.0
mecab-ko-dic >=1.0.0, <1.1.0
ipadic >=1.0.0, <1.1.0
sentencepiece >=0.1.98, <=0.1.99
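
All five new pins are optional text dependencies: mecab-python3 and ipadic back the new `ja-mecab` tokenizer, mecab-ko and mecab-ko-dic back `ko-mecab`, and sentencepiece backs `flores101`/`flores200`. A quick availability check (not part of the PR) — note that the import names differ from the pip package names:

```python
# Hypothetical availability check, not part of this PR: mecab-python3
# imports as "MeCab" and mecab-ko-dic as "mecab_ko_dic".
import importlib.util

for module in ("MeCab", "ipadic", "mecab_ko", "mecab_ko_dic", "sentencepiece"):
    found = importlib.util.find_spec(module) is not None
    print(f"{module}: {'available' if found else 'missing'}")
```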
2 changes: 1 addition & 1 deletion requirements/text_test.txt
@@ -5,4 +5,4 @@ jiwer >=2.3.0, <3.1.0
rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
huggingface-hub <0.18 # hotfix, failing SDR for latest PT 1.11
sacrebleu >=2.0.0, <=2.3.1
sacrebleu >=2.3.0, <=2.3.1
199 changes: 177 additions & 22 deletions src/torchmetrics/functional/text/sacre_bleu.py
@@ -37,18 +37,28 @@
# MIT License
# Copyright (c) 2017 - Shujian Huang <[email protected]>

import os
import re
import tempfile
from functools import partial
from typing import ClassVar, Optional, Sequence
from typing import Any, ClassVar, Dict, Optional, Sequence

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
from torchmetrics.utilities.imports import _REGEX_AVAILABLE
from torchmetrics.utilities.imports import (
_IPADIC_AVAILABLE,
_MECAB_AVAILABLE,
_MECAB_KO_AVAILABLE,
_MECAB_KO_DIC_AVAILABLE,
_REGEX_AVAILABLE,
_SENTENCEPIECE_AVAILABLE,
)

AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")
AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200")
_Tokenizers_list = Literal["none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200"]
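
Because the runtime tuple and the `Literal` annotation must be kept in sync by hand, the PR also adds a test asserting their equivalence (commit a3bcbd3). A minimal sketch of what that check needs:

```python
# Minimal sketch of the tuple/Literal consistency check added by this PR;
# get_args unpacks the values accepted by the Literal annotation.
from typing_extensions import get_args

from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _Tokenizers_list

assert set(get_args(_Tokenizers_list)) == set(AVAILABLE_TOKENIZERS)
```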

_UCODE_RANGES = (
("\u3400", "\u4db5"), # CJK Unified Ideographs Extension A, release 3.0
@@ -77,6 +87,14 @@
)


_FLORES_LOCAL_DIR = os.path.join(tempfile.gettempdir(), "torchmetrics-flores")
# Model URLs copied from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_spm.py.
_FLORES_MODELS_URL = {
"flores101": "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model",
"flores200": "https://tinyurl.com/flores200sacrebleuspm",
}


class _SacreBLEUTokenizer:
"""Tokenizer used for SacreBLEU calculation.

Expand Down Expand Up @@ -116,9 +134,18 @@ class _SacreBLEUTokenizer:
"zh": "_tokenize_zh",
"intl": "_tokenize_international",
"char": "_tokenize_char",
"ja-mecab": "_tokenize_ja_mecab",
"ko-mecab": "_tokenize_ko_mecab",
"flores101": "_tokenize_flores_101",
"flores200": "_tokenize_flores_200",
}

def __init__(self, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False) -> None:
# Keep it as class variable to avoid initializing over and over again
sentencepiece_processors: ClassVar[Dict[str, Optional[Any]]] = {"flores101": None, "flores200": None}

def __init__(self, tokenize: _Tokenizers_list, lowercase: bool = False) -> None:
self._check_tokenizers_validity(tokenize)

self.tokenize_fn = getattr(self, self._TOKENIZE_FN[tokenize])
self.lowercase = lowercase

@@ -127,9 +154,9 @@ def __call__(self, line: str) -> Sequence[str]:
return self._lower(tokenized_line, self.lowercase).split()

@classmethod
def tokenize(
cls, line: str, tokenize: Literal["none", "13a", "zh", "intl", "char"], lowercase: bool = False
) -> Sequence[str]:
def tokenize(cls, line: str, tokenize: _Tokenizers_list, lowercase: bool = False) -> Sequence[str]:
cls._check_tokenizers_validity(tokenize)

tokenize_fn = getattr(cls, cls._TOKENIZE_FN[tokenize])
tokenized_line = tokenize_fn(line)
return cls._lower(tokenized_line, lowercase).split()
@@ -274,19 +301,159 @@ def _tokenize_char(cls, line: str) -> str:
"""
return " ".join(char for char in line)

@classmethod
def _tokenize_ja_mecab(cls, line: str) -> str:
"""Tokenizes a Japanese string line using the MeCab morphological analyzer.

Args:
line: the input string to tokenize.

Return:
The tokenized string.

"""
import ipadic
import MeCab

tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")

line = line.strip()
return tagger.parse(line).strip()

@classmethod
def _tokenize_ko_mecab(cls, line: str) -> str:
"""Tokenizes a Korean string line using the MeCab-ko morphological analyzer.

Args:
line: the input string to tokenize.

Return:
The tokenized string.

"""
import mecab_ko
import mecab_ko_dic

tagger = mecab_ko.Tagger(mecab_ko_dic.MECAB_ARGS + " -Owakati")

line = line.strip()
return tagger.parse(line).strip()
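
Both MeCab code paths share the same shape: build a tagger from the dictionary's `MECAB_ARGS` plus `-Owakati` (whitespace-separated output), parse, and strip. A standalone sketch of the two tagger calls, assuming the optional packages are installed:

```python
# Standalone sketch of the two tagger calls above; assumes mecab-python3,
# ipadic, mecab-ko and mecab-ko-dic are installed.
import ipadic
import MeCab
import mecab_ko
import mecab_ko_dic

ja_tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
print(ja_tagger.parse("これはペンです").strip())  # -> これ は ペン です

ko_tagger = mecab_ko.Tagger(mecab_ko_dic.MECAB_ARGS + " -Owakati")
print(ko_tagger.parse("안녕하세요").strip())  # -> space-separated morphemes
```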

@classmethod
def _tokenize_flores(cls, line: str, tokenize: Literal["flores101", "flores200"]) -> str:
"""Tokenizes a string line using a SentencePiece tokenizer.

Args:
line: the input string to tokenize.
tokenize: Tokenization technique to be used.

Return:
The tokenized string.

"""
import sentencepiece

if cls.sentencepiece_processors[tokenize] is None:
cls.sentencepiece_processors[tokenize] = sentencepiece.SentencePieceProcessor()

file_path = os.path.join(_FLORES_LOCAL_DIR, _FLORES_MODELS_URL[tokenize].split("/")[-1])
if not os.path.exists(file_path):
cls.download_flores_file(tokenize)

cls.sentencepiece_processors[tokenize].Load(file_path) # type: ignore[union-attr]

return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line)) # type: ignore[union-attr]

@classmethod
def _tokenize_flores_101(cls, line: str) -> str:
"""Tokenizes a string line using the SentencePiece tokenizer released with the `FLORES-101`_ dataset.

Args:
line: the input string to tokenize.

Return:
The tokenized string.

"""
return cls._tokenize_flores(line, "flores101")

@classmethod
def _tokenize_flores_200(cls, line: str) -> str:
"""Tokenizes a string line using the SentencePiece tokenizer released with the `FLORES-200`_ dataset.

Args:
line: the input string to tokenize.

Return:
The tokenized string.

"""
return cls._tokenize_flores(line, "flores200")

@staticmethod
def _lower(line: str, lowercase: bool) -> str:
if lowercase:
return line.lower()
return line

@classmethod
def _check_tokenizers_validity(cls, tokenize: _Tokenizers_list) -> None:
"""Check if a supported tokenizer is chosen.

Also check that all dependencies of the given tokenizer are installed.

"""
if tokenize not in cls._TOKENIZE_FN:
raise ValueError(f"Unsupported tokenizer selected. Please, choose one of {list(cls._TOKENIZE_FN.keys())}")

if tokenize == "intl" and not _REGEX_AVAILABLE:
raise ModuleNotFoundError(
"`'intl'` tokenization requires that `regex` is installed."
" Use `pip install regex` or `pip install torchmetrics[text]`."
)

if tokenize == "ja-mecab" and not (_MECAB_AVAILABLE and _IPADIC_AVAILABLE):
raise ModuleNotFoundError(
"`'ja-mecab'` tokenization requires that `MeCab` and `ipadic` are installed."
" Use `pip install mecab-python3 ipadic` or `pip install torchmetrics[text]`."
)

if tokenize == "ko-mecab" and not (_MECAB_KO_AVAILABLE and _MECAB_KO_DIC_AVAILABLE):
raise ModuleNotFoundError(
"`'ko-mecab'` tokenization requires that `mecab_ko` and `mecab_ko_dic` are installed."
" Use `pip install mecab_ko mecab_ko_dic` or `pip install torchmetrics[text]`."
)

if "flores" in tokenize and not _SENTENCEPIECE_AVAILABLE:
raise ModuleNotFoundError(
"`'flores101' and 'flores200'` tokenizations require that `sentencepiece` is installed."
" Use `pip install sentencepiece` or `pip install torchmetrics[text]`."
)

@staticmethod
def download_flores_file(model_name: Literal["flores101", "flores200"]) -> None:
"""Download necessary files for `flores` tokenization via `sentencepiece`."""
import ssl
import urllib.request

os.makedirs(_FLORES_LOCAL_DIR, exist_ok=True)

model_url = _FLORES_MODELS_URL[model_name]
file_path = os.path.join(_FLORES_LOCAL_DIR, model_url.split("/")[-1])

try:
with open(file_path, "wb") as out_file, urllib.request.urlopen(model_url) as remote_file:
out_file.write(remote_file.read())
except ssl.SSLError as e:
raise OSError(f"Failed to download {model_name} model.") from e
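
Since the model lands in a shared temp directory, the download happens at most once per machine until that directory is cleared; `_tokenize_flores` triggers it lazily when the file is missing. It can also be called up front, e.g. to warm a CI cache — a sketch using only names defined in this module:

```python
# Sketch of a manual pre-download using this module's own helpers.
import os

from torchmetrics.functional.text.sacre_bleu import (
    _FLORES_LOCAL_DIR,
    _FLORES_MODELS_URL,
    _SacreBLEUTokenizer,
)

_SacreBLEUTokenizer.download_flores_file("flores101")
model_file = os.path.join(_FLORES_LOCAL_DIR, _FLORES_MODELS_URL["flores101"].split("/")[-1])
assert os.path.isfile(model_file)
```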


def sacre_bleu_score(
preds: Sequence[str],
target: Sequence[Sequence[str]],
n_gram: int = 4,
smooth: bool = False,
tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
tokenize: _Tokenizers_list = "13a",
lowercase: bool = False,
weights: Optional[Sequence[float]] = None,
) -> Tensor:
@@ -299,8 +466,8 @@ def sacre_bleu_score(
target: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4
smooth: Whether to apply smoothing - see [2]
tokenize: Tokenization technique to be used.
Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
lowercase: If ``True``, BLEU score over lowercased text is calculated.
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
@@ -330,20 +497,8 @@
and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_

"""
if tokenize not in AVAILABLE_TOKENIZERS:
raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")

if tokenize not in _SacreBLEUTokenizer._TOKENIZE_FN:
raise ValueError(
f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}"
)
if len(preds) != len(target):
raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}")
if tokenize == "intl" and not _REGEX_AVAILABLE:
raise ModuleNotFoundError(
"`'intl'` tokenization requires that `regex` is installed."
" Use `pip install regex` or `pip install torchmetrics[text]`."
)

if weights is not None and len(weights) != n_gram:
raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
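End to end, the functional API now accepts the new tokenizer names directly; a usage sketch (requires the optional Japanese dependencies, mecab-python3 and ipadic):

```python
# Usage sketch for the extended functional API with the ja-mecab tokenizer.
from torchmetrics.functional.text import sacre_bleu_score

preds = ["猫はマットの上に座った"]
target = [["猫がマットの上に座っていた"]]
print(sacre_bleu_score(preds, target, tokenize="ja-mecab"))  # 0-dim score tensor
```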
22 changes: 5 additions & 17 deletions src/torchmetrics/text/sacre_bleu.py
@@ -20,21 +20,17 @@
from typing import Any, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.text.bleu import _bleu_score_update
from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer
from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer, _Tokenizers_list
from torchmetrics.text.bleu import BLEUScore
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _REGEX_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SacreBLEUScore.plot"]


AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char")


class SacreBLEUScore(BLEUScore):
"""Calculate `BLEU score`_ of machine translated text with one or more references.

@@ -53,8 +49,8 @@ class SacreBLEUScore(BLEUScore):
Args:
n_gram: Gram value ranged from 1 to 4
smooth: Whether to apply smoothing, see `SacreBLEU`_
tokenize: Tokenization technique to be used.
Supported tokenization: ``['none', '13a', 'zh', 'intl', 'char']``
tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
lowercase: If ``True``, BLEU score over lowercased text is calculated.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
weights:
@@ -95,20 +91,12 @@ def __init__(
self,
n_gram: int = 4,
smooth: bool = False,
tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
tokenize: _Tokenizers_list = "13a",
lowercase: bool = False,
weights: Optional[Sequence[float]] = None,
**kwargs: Any,
) -> None:
super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
if tokenize not in AVAILABLE_TOKENIZERS:
raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.")

if tokenize == "intl" and not _REGEX_AVAILABLE:
raise ModuleNotFoundError(
"`'intl'` tokenization requires that `regex` is installed."
" Use `pip install regex` or `pip install torchmetrics[text]`."
)
self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase)

def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None:
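With the duplicated validation removed, the modular metric defers all tokenizer checks to `_SacreBLEUTokenizer` and picks up the new options unchanged — a usage sketch (requires sentencepiece; the flores200 model is fetched on first update):

```python
# Usage sketch for the modular metric with one of the new tokenizers.
from torchmetrics.text import SacreBLEUScore

metric = SacreBLEUScore(tokenize="flores200")
metric.update(["the cat sat on the mat"], [["a cat sat on the mat"]])
print(metric.compute())
```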