From a8dac088b76da3a4f159c78d5958d04d351cddd0 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 16 Nov 2020 14:00:04 +0100 Subject: [PATCH 1/3] models: add support for Hugging Face model hub --- flair/models/sequence_tagger_model.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index b7f82d0705..8ae5e9d715 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -1104,6 +1104,24 @@ def _fetch_model(model_name) -> str: unzip_file(Path(flair.cache_root) / cache_dir / 'freeIndirect.zip', Path(flair.cache_root) / cache_dir) model_name = str(Path(flair.cache_root) / cache_dir / 'freeIndirect' / 'final-model.pt') + # Fallback to Hugging Face model hub + if not Path(model_name).exists() and not model_name.startswith("http"): + # e.g. stefan-it/flair-ner-conll03 is a valid model ID (namespace/model name), + # and stefan-it/flair-ner-conll03@main additionally selects a revision (branch name or commit) + hf_model_name = "model.bin" + revision = "main" + + if "@" in model_name: + model_name_splitted = model_name.split("@") + revision = model_name_splitted[-1] + model_name = model_name_splitted[0] + + # Lazy import + from transformers import file_utils + + url = file_utils.hf_bucket_url(model_id=model_name, revision=revision, filename=hf_model_name) + model_name = file_utils.cached_path(url_or_filename=url) + return model_name def get_transition_matrix(self): From 0e9330194733f8a43821539c0c20854a5cdb4886 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 23 Nov 2020 12:50:31 +0100 Subject: [PATCH 2/3] models: use flair cache dir when using flair-models from Hugging Face model hub --- flair/models/sequence_tagger_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 8ae5e9d715..4477885c9f 100644 --- a/flair/models/sequence_tagger_model.py +++ 
b/flair/models/sequence_tagger_model.py @@ -1120,7 +1120,7 @@ def _fetch_model(model_name) -> str: from transformers import file_utils url = file_utils.hf_bucket_url(model_id=model_name, revision=revision, filename=hf_model_name) - model_name = file_utils.cached_path(url_or_filename=url) + model_name = file_utils.cached_path(url_or_filename=url, cache_dir=flair.cache_root) return model_name From cd56d1aa9e70e7a4219cbe7c2a8d2c2a5ea7ef1f Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Wed, 25 Nov 2020 10:15:05 +0100 Subject: [PATCH 3/3] pip: pin Transformers version (3.5.0 and 3.5.1) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a072dd28be..820e42fcd0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ scikit-learn>=0.21.3 sqlitedict>=1.6.0 deprecated>=1.2.4 hyperopt>=0.1.1 -transformers>=3.0.0 +transformers>=3.5.0,<=3.5.1 bpemb>=0.3.2 regex tabulate