diff --git a/python/openvino_tokenizers/cli.py b/python/openvino_tokenizers/cli.py index ded5de9c..11501f9a 100644 --- a/python/openvino_tokenizers/cli.py +++ b/python/openvino_tokenizers/cli.py @@ -237,11 +237,14 @@ def get_parser() -> ArgumentParser: "--utf8_replace_mode", choices=list(UTF8ReplaceMode), type=UTF8ReplaceMode, # enum with 'ignore', 'replace' values. - default=None, + default=UTF8ReplaceMode.REPLACE, required=False, help=( "If specified then resulting strings during decoding are checked if sequence of bytes is a valid UTF-8 sequence. " - f"If mode is '{UTF8ReplaceMode.REPLACE}' then invalid characters are replaced with �, if mode is '{UTF8ReplaceMode.IGNORE}' then invalid character are skipped." + f"If mode is '{UTF8ReplaceMode.DISABLE}' then UTF-8 validation is not performed at all. " + f"The other two modes behave like the error handling parameter of Python's bytes.decode method. " + f"If mode is '{UTF8ReplaceMode.REPLACE}' then invalid characters are replaced with �. " + f"If mode is '{UTF8ReplaceMode.IGNORE}' then invalid characters are skipped, i.e. replaced with an empty substring." 
), ) return parser diff --git a/python/openvino_tokenizers/constants.py b/python/openvino_tokenizers/constants.py index 02894bdc..e793ae51 100644 --- a/python/openvino_tokenizers/constants.py +++ b/python/openvino_tokenizers/constants.py @@ -39,6 +39,16 @@ class UTF8ReplaceMode(Enum): IGNORE: str = "ignore" REPLACE: str = "replace" + DISABLE: str = "disable" def __str__(self): return self.value + + def __eq__(self, other): + if isinstance(other, (UTF8ReplaceMode)): + # enum members are singletons, so identity comparison is sufficient here + return self is other + elif isinstance(other, str): + return self.value == other + else: + return False diff --git a/python/openvino_tokenizers/convert_tokenizer.py b/python/openvino_tokenizers/convert_tokenizer.py index ee51c4e3..34de5a05 100644 --- a/python/openvino_tokenizers/convert_tokenizer.py +++ b/python/openvino_tokenizers/convert_tokenizer.py @@ -71,7 +71,7 @@ def convert_tokenizer( use_max_padding: bool = False, handle_special_tokens_with_re: Optional[bool] = None, use_sentencepiece_backend: bool = False, - utf8_replace_mode: Optional[UTF8ReplaceMode] = None, + utf8_replace_mode: Optional[UTF8ReplaceMode] = UTF8ReplaceMode.REPLACE, ) -> Union[Model, Tuple[Model, Model]]: """ Converts a given tokenizer object into an OpenVINO-compatible model. 
diff --git a/python/openvino_tokenizers/hf_parser.py b/python/openvino_tokenizers/hf_parser.py index c5df2e7c..75d357d7 100644 --- a/python/openvino_tokenizers/hf_parser.py +++ b/python/openvino_tokenizers/hf_parser.py @@ -390,8 +390,8 @@ def decoding(self) -> None: self.pipeline.add_steps(CharsToBytesStep()) else: self.pipeline.add_steps(FuseStep()) - - if self.utf8_replace_mode is not None: + + if self.utf8_replace_mode is not None and (self.utf8_replace_mode != UTF8ReplaceMode.DISABLE): self.pipeline.add_steps(UTF8ValidateStep(mode=self.utf8_replace_mode)) if self.clean_up_tokenization_spaces is None: @@ -981,12 +981,12 @@ def get_sp_detokenizer( if params.clean_up_tokenization_spaces: detokenizer = RegexDecodingStep.clean_up_tokenization_spaces().get_ov_subgraph(detokenizer) + + last_sinks = detokenizer + if params.utf8_replace_mode is not None and params.utf8_replace_mode != UTF8ReplaceMode.DISABLE: + last_sinks = UTF8ValidateStep(params.utf8_replace_mode).get_ov_subgraph(detokenizer) - if params.utf8_replace_mode is not None: - replace_mode = True if params.utf8_replace_mode is UTF8ReplaceMode.REPLACE else False - UTF8ValidateStep(mode=replace_mode).get_ov_subgraph(detokenizer) - - string_output = _get_factory().create("StringTensorPack", detokenizer).outputs() + string_output = _get_factory().create("StringTensorPack", last_sinks).outputs() string_output[0].tensor.add_names({STRING_OUTPUT_NAME}) tokenizer_detokenizer = Model(string_output, [model_input], DETOKENIZER_NAME) tokenizer_detokenizer.validate_nodes_and_infer_types() diff --git a/python/openvino_tokenizers/tokenizer_pipeline.py b/python/openvino_tokenizers/tokenizer_pipeline.py index ab8926a7..52417a3b 100644 --- a/python/openvino_tokenizers/tokenizer_pipeline.py +++ b/python/openvino_tokenizers/tokenizer_pipeline.py @@ -1043,10 +1043,10 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]: @dataclass class UTF8ValidateStep(DecodingStep): - mode: UTF8ReplaceMode = 
UTF8ReplaceMode.IGNORE + mode: UTF8ReplaceMode = field(default_factory=lambda: UTF8ReplaceMode.IGNORE) def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]: - replace_mode = True if self.mode is UTF8ReplaceMode.REPLACE else False + replace_mode = True if self.mode == UTF8ReplaceMode.REPLACE else False return _get_factory().create("UTF8Validate", input_nodes, {"replace_mode": replace_mode}).outputs() diff --git a/python/openvino_tokenizers/utils.py b/python/openvino_tokenizers/utils.py index 6b483a8b..ea1027e8 100644 --- a/python/openvino_tokenizers/utils.py +++ b/python/openvino_tokenizers/utils.py @@ -4,7 +4,7 @@ import logging import re -from dataclasses import dataclass, fields +from dataclasses import dataclass, fields, field from functools import lru_cache from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -57,7 +57,7 @@ class TokenzierConversionParams: utf8_replace_mode : Optional[UTF8ReplaceMode] Specifies the UTF-8 replacement mode during tokenization. - Allowed values are UTF8ReplaceMode.IGNORE and UTF8ReplaceMode.REPLACE. Default is None. + Allowed values are UTF8ReplaceMode.DISABLE, UTF8ReplaceMode.IGNORE and UTF8ReplaceMode.REPLACE. Default is UTF8ReplaceMode.REPLACE. 
""" with_detokenizer: bool = False @@ -70,7 +70,7 @@ class TokenzierConversionParams: use_max_padding: bool = False handle_special_tokens_with_re: Optional[bool] = None use_sentencepiece_backend: bool = False - utf8_replace_mode: Optional[UTF8ReplaceMode] = None + utf8_replace_mode: Optional[UTF8ReplaceMode] = field(default_factory=lambda: UTF8ReplaceMode.REPLACE) add_attention_mask: bool = True add_prefix_space: Optional[bool] = None number_of_inputs: int = 1 diff --git a/tests/layer_tests.py b/tests/layer_tests.py index e625f636..17ed84f4 100644 --- a/tests/layer_tests.py +++ b/tests/layer_tests.py @@ -82,9 +82,7 @@ def create_normalization_model(layer: Union[NormalizationStep, DecodingStep]) -> @pytest.mark.parametrize("test_string", utf8_validate_strings) @pytest.mark.parametrize("replace_mode", ["ignore", "replace"]) def test_utf8_validate(test_string, replace_mode): - utf_validation_node = UTF8ValidateStep( - UTF8ReplaceMode.REPLACE if replace_mode == "replace" else UTF8ReplaceMode.IGNORE - ) + utf_validation_node = UTF8ValidateStep(UTF8ReplaceMode(replace_mode)) compiled_model = create_normalization_model(utf_validation_node) res_ov = compiled_model([test_string])[0] res_py = test_string.decode(errors=replace_mode)