-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #66 from ssciwr/model_manager
add `ModelManager`
- Loading branch information
Showing
6 changed files
with
409 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import huggingface_hub | ||
import spacy_huggingface_hub | ||
import os | ||
import spacy | ||
from pathlib import Path | ||
from typing import Union, Optional, Dict, Any | ||
import tempfile | ||
import re | ||
import logging | ||
|
||
|
||
def _construct_wheel_path(model_path: Path, meta: Dict[str, Any]) -> Path: | ||
full_name = f"{meta['lang']}_{meta['name']}-{meta['version']}" | ||
return model_path / full_name / "dist" / f"{full_name}-py3-none-any.whl" | ||
|
||
|
||
def _make_valid_package_name(name: str) -> str: | ||
# attempt to make name valid, throw exception if we fail | ||
# https://packaging.python.org/en/latest/specifications/name-normalization | ||
valid_name = re.sub(r"[-_.,<>!@#$%^&*()+ /?]+", "_", name).lower().strip("_") | ||
if name != valid_name: | ||
logging.warning( | ||
f"'{name}' not a valid package name, using '{valid_name}' instead" | ||
) | ||
if ( | ||
re.match("^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", valid_name, re.IGNORECASE) | ||
is None | ||
): | ||
raise ValueError( | ||
"Invalid package name: Can only contain ASCII letters, numbers and underscore." | ||
) | ||
return valid_name | ||
|
||
|
||
class ModelManager: | ||
""" | ||
Import, modify and publish models to hugging face. | ||
""" | ||
|
||
_meta_keys_to_expose_to_user = [ | ||
"name", | ||
"version", | ||
"description", | ||
"author", | ||
"email", | ||
"url", | ||
"license", | ||
] | ||
|
||
def __init__(self, model_path: Union[str, Path] = None): | ||
self.load(model_path) | ||
|
||
def load(self, model_path: Union[str, Path]): | ||
"""Load a spacy model from `model_path`.""" | ||
self.model_path = Path(model_path) | ||
self.spacy_model = spacy.load(model_path) | ||
self.metadata = { | ||
k: self.spacy_model.meta.get(k, "") | ||
for k in self._meta_keys_to_expose_to_user | ||
} | ||
|
||
def save(self): | ||
"""Save any changes made to the model metadata.""" | ||
self._update_metadata() | ||
self.spacy_model.to_disk(self.model_path) | ||
|
||
def publish(self, hugging_face_token: Optional[str] = None) -> Dict[str, str]: | ||
"""Publish the model to Hugging Face. | ||
This requires a User Access Token from https://huggingface.co/ | ||
The token can either be passed via the `hugging_face_token` argument, | ||
or it can be set via the `HUGGING_FACE_TOKEN` environment variable. | ||
Args: | ||
hugging_face_token (str, optional): Hugging Face User Access Token | ||
Returns: | ||
dict: URLs of the published model and the pip-installable wheel | ||
""" | ||
self.save() | ||
if hugging_face_token is None: | ||
hugging_face_token = os.environ.get("HUGGING_FACE_TOKEN") | ||
if hugging_face_token is None: | ||
raise ValueError( | ||
"API TOKEN required: pass as string or set the HUGGING_FACE_TOKEN environment variable." | ||
) | ||
huggingface_hub.login(token=hugging_face_token) | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
# convert model to a python package incl binary wheel | ||
output_path = Path(tmpdir) | ||
spacy.cli.package(self.model_path, output_path, create_wheel=True) | ||
# push the package to hugging face | ||
return spacy_huggingface_hub.push( | ||
_construct_wheel_path(output_path, self.spacy_model.meta) | ||
) | ||
|
||
def _update_metadata(self): | ||
self.metadata["name"] = _make_valid_package_name(self.metadata.get("name")) | ||
for k, v in self.metadata.items(): | ||
if k in self.spacy_model.meta: | ||
self.spacy_model.meta[k] = v |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,44 @@ | ||
import pytest | ||
from moralization import input_data | ||
from moralization.data_manager import DataManager | ||
import pathlib | ||
|
||
|
||
def _data_path_fixture(dir_path): | ||
@pytest.fixture | ||
def _fixture(): | ||
return dir_path | ||
@pytest.fixture(scope="session") | ||
def data_dir(): | ||
return pathlib.Path(__file__).parents[1].resolve() / "data" | ||
|
||
return _fixture | ||
|
||
@pytest.fixture(scope="session") | ||
def ts_file(data_dir): | ||
return data_dir / "TypeSystem.xml" | ||
|
||
def _doc_dict_fixture(dir_path): | ||
@pytest.fixture | ||
def _fixture(): | ||
return input_data.InputOutput.read_data(dir_path) | ||
|
||
return _fixture | ||
@pytest.fixture(scope="session") | ||
def data_file(data_dir): | ||
return ( | ||
data_dir / "test_data-trimmed_version_of-Interviews-pos-SH-neu-optimiert-AW.xmi" | ||
) | ||
|
||
|
||
dir_path = pathlib.Path(__file__).parents[1].resolve() / "data" | ||
data_dir = _data_path_fixture(dir_path) | ||
doc_dicts = _doc_dict_fixture(dir_path) | ||
@pytest.fixture(scope="session") | ||
def config_file(data_dir): | ||
return data_dir / "config.cfg" | ||
|
||
|
||
ts_file = _data_path_fixture(dir_path / "TypeSystem.xml") | ||
data_file = _data_path_fixture( | ||
dir_path / "test_data-trimmed_version_of-Interviews-pos-SH-neu-optimiert-AW.xmi" | ||
) | ||
config_file = _data_path_fixture(dir_path / "config.cfg") | ||
@pytest.fixture(scope="session") | ||
def model_path(data_dir, config_file, tmp_path_factory) -> pathlib.Path: | ||
""" | ||
Returns a temporary path containing a trained model. | ||
This is only created once and re-used for the entire pytest session. | ||
""" | ||
dm = DataManager(data_dir) | ||
dm.export_data_DocBin() | ||
tmp_path = tmp_path_factory.mktemp("model") | ||
dm.spacy_train(working_dir=tmp_path, config=config_file, n_epochs=1) | ||
yield tmp_path / "output" / "model-best" | ||
|
||
|
||
@pytest.fixture | ||
def doc_dicts(data_dir): | ||
return input_data.InputOutput.read_data(str(data_dir)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from moralization.model_manager import ModelManager | ||
import spacy | ||
import pytest | ||
import spacy_huggingface_hub | ||
import huggingface_hub | ||
from typing import Any | ||
from pathlib import Path | ||
|
||
|
||
def test_model_manager_valid_path(model_path): | ||
model = ModelManager(model_path) | ||
assert model.spacy_model is not None | ||
assert model.spacy_model.lang == "de" | ||
assert model.spacy_model.path == model_path | ||
|
||
|
||
def test_model_manager_modify_metadata(model_path): | ||
model = ModelManager(model_path) | ||
# update metadata values and save model | ||
keys = ["name", "version", "description", "author", "email", "url", "license"] | ||
for key in keys: | ||
model.metadata[key] = f"{key}" | ||
model.save() | ||
for key in keys: | ||
assert model.metadata[key] == f"{key}" | ||
# re-load model | ||
model.load(model_path) | ||
for key in keys: | ||
assert model.metadata[key] == f"{key}" | ||
# load model directly in spacy and check its meta has also been updated | ||
nlp = spacy.load(model_path) | ||
for key in keys: | ||
assert nlp.meta[key] == f"{key}" | ||
|
||
|
||
def test_model_manager_modify_metadata_fixable_invalid_names(model_path): | ||
model = ModelManager(model_path) | ||
for invalid_name, valid_name in [("!hm & __OK?,...", "hm_ok"), ("Im - S", "im_s")]: | ||
model.metadata["name"] = invalid_name | ||
assert model.metadata["name"] == invalid_name | ||
# name is made valid on call to save() | ||
model.save() | ||
assert model.metadata["name"] == valid_name | ||
nlp = spacy.load(model_path) | ||
assert nlp.meta["name"] == valid_name | ||
|
||
|
||
def test_model_manager_modify_metadata_unfixable_invalid_names(model_path): | ||
model = ModelManager(model_path) | ||
for unfixable_invalid_name in ["", "_", "ü"]: | ||
model.metadata["name"] = unfixable_invalid_name | ||
with pytest.raises(ValueError) as e: | ||
model.save() | ||
assert "invalid" in str(e.value).lower() | ||
|
||
|
||
def test_model_manager_publish_no_token(model_path, monkeypatch): | ||
monkeypatch.delenv("HUGGING_FACE_TOKEN", raising=False) | ||
model = ModelManager(model_path) | ||
with pytest.raises(ValueError) as e: | ||
model.publish() | ||
assert "token" in str(e.value).lower() | ||
|
||
|
||
def test_model_manager_publish_invalid_token_env(model_path, monkeypatch): | ||
monkeypatch.setenv("HUGGING_FACE_TOKEN", "invalid") | ||
model = ModelManager(model_path) | ||
with pytest.raises(ValueError) as e: | ||
model.publish() | ||
assert "token" in str(e.value).lower() | ||
|
||
|
||
def test_model_manager_publish_invalid_token_arg(model_path): | ||
model = ModelManager(model_path) | ||
with pytest.raises(ValueError) as e: | ||
model.publish(hugging_face_token="invalid") | ||
assert "token" in str(e.value).lower() | ||
|
||
|
||
def test_model_manager_publish_mock_push(model_path: Path, monkeypatch, tmp_path): | ||
def mock_spacy_huggingface_hub_push(whl_path: Path): | ||
whl_path.rename(tmp_path / whl_path.name) | ||
return {} | ||
|
||
# monkey patch spacy_huggingface_hub.push() to just move the supplied wheel to a temporary path | ||
monkeypatch.setattr(spacy_huggingface_hub, "push", mock_spacy_huggingface_hub_push) | ||
|
||
def do_nothing(*args: Any, **kwargs: Any) -> None: | ||
return | ||
|
||
# monkey patch huggingface_hub.login() to do nothing | ||
monkeypatch.setattr(huggingface_hub, "login", do_nothing) | ||
|
||
model = ModelManager(model_path) | ||
# set name and version - these determine the name of the compiled wheel | ||
model.metadata["name"] = "my_new_pipeline" | ||
model.metadata["version"] = "1.2.3" | ||
model.publish(hugging_face_token="abc123") | ||
wheel_path = tmp_path / "de_my_new_pipeline-1.2.3-py3-none-any.whl" | ||
assert wheel_path.is_file() |
Oops, something went wrong.