From 744cc68ad7c7ace1c6946d20cae7c023623cad78 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 8 Nov 2024 13:22:50 -0500 Subject: [PATCH 1/2] Fix moses punctuation --- machine/translation/huggingface/hugging_face_nmt_engine.py | 4 ++-- .../huggingface/hugging_face_nmt_model_trainer.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 9d43b28..4da476c 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -55,8 +55,8 @@ def __init__( self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True) if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)): self._mpn = MosesPunctNormalizer() - self._mpn.substitutions = [ - (str(re.compile(r)), sub) + self._mpn.substitutions = [ # type: ignore + (re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str) ] diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index f15ba4e..1192243 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -100,10 +100,8 @@ def __init__( self._add_unk_src_tokens = add_unk_src_tokens self._add_unk_tgt_tokens = add_unk_tgt_tokens self._mpn = MosesPunctNormalizer() - self._mpn.substitutions = [ - (str(re.compile(r)), sub) - for r, sub in self._mpn.substitutions - if isinstance(r, str) and isinstance(sub, str) + self._mpn.substitutions = [ # type: ignore + (re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str) ] self._stats = TrainStats() From 5e2ac4721c642b19ca325f6511ca6031875ee297 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 8 
Nov 2024 14:45:04 -0500 Subject: [PATCH 2/2] max_length actually not there - but it's in the generation_config! --- .../translation/huggingface/hugging_face_nmt_engine.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 4da476c..04086af 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs): input_tokens = model_inputs["input_tokens"] del model_inputs["input_tokens"] - generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length) - generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) + if hasattr(self.model, "generation_config") and self.model.generation_config is not None: + config = self.model.generation_config + else: + config = self.model.config + generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length) + generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) output = self.model.generate( **model_inputs,