Skip to content

Commit

Permalink
Merge pull request #66 from ssciwr/model_manager
Browse files Browse the repository at this point in the history
add `ModelManager`
  • Loading branch information
lkeegan authored Mar 24, 2023
2 parents d092864 + 69059dd commit e5341be
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 22 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ on:

jobs:
test:
name: "${{ matrix.os }} :: ${{ matrix.python-version }}"
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04]
os: [ubuntu-latest]
python-version: [3.9]
steps:
- name: Checkout repository
Expand All @@ -30,7 +31,7 @@ jobs:
- name: Run pytest
run: |
cd moralization
python -m pytest -s --cov=. --cov-report=xml
python -m pytest -v -s --cov=. --cov-report=xml
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
Expand Down
101 changes: 101 additions & 0 deletions moralization/model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import huggingface_hub
import spacy_huggingface_hub
import os
import spacy
from pathlib import Path
from typing import Union, Optional, Dict, Any
import tempfile
import re
import logging


def _construct_wheel_path(model_path: Path, meta: Dict[str, Any]) -> Path:
    """Return the expected location of the wheel of a packaged model.

    The package directory and the wheel file share the stem
    ``<lang>_<name>-<version>``, assembled from the ``lang``, ``name`` and
    ``version`` entries of ``meta``. The wheel is looked for in the ``dist``
    sub-directory of that package directory, below ``model_path``.
    """
    package_stem = "{lang}_{name}-{version}".format(
        lang=meta["lang"], name=meta["name"], version=meta["version"]
    )
    wheel_filename = package_stem + "-py3-none-any.whl"
    return model_path.joinpath(package_stem, "dist", wheel_filename)


def _make_valid_package_name(name: str) -> str:
    """Turn ``name`` into a valid package name, or raise if that fails.

    Every run of separator / punctuation characters is collapsed into a
    single underscore, the result is lower-cased and leading / trailing
    underscores are removed. A warning is logged if this changed the name.

    See https://packaging.python.org/en/latest/specifications/name-normalization

    Args:
        name: The package name requested by the user.

    Returns:
        The sanitised name: ASCII lower-case letters, digits and interior
        underscores only, starting and ending with a letter or digit.

    Raises:
        ValueError: If the sanitised name is empty or still contains
            characters that are not allowed.
    """
    valid_name = re.sub(r"[-_.,<>!@#$%^&*()+ /?]+", "_", name).lower().strip("_")
    if name != valid_name:
        logging.warning(
            f"'{name}' not a valid package name, using '{valid_name}' instead"
        )
    # `valid_name` is already lower-cased and free of "." and "-", so an
    # explicit ASCII class is enough. Two pitfalls are avoided on purpose:
    #  * "[A-Z]" combined with re.IGNORECASE also matches a few non-ASCII
    #    letters (e.g. "ſ", "ı") for str patterns;
    #  * "$" also matches just before a trailing newline, so match()+"$"
    #    would accept "abc\n" - fullmatch() does not.
    if re.fullmatch(r"[a-z0-9]([a-z0-9_]*[a-z0-9])?", valid_name) is None:
        raise ValueError(
            "Invalid package name: Can only contain ASCII letters, numbers and underscore."
        )
    return valid_name


class ModelManager:
    """
    Import, modify and publish models to hugging face.

    After `load()` the instance provides:

    - `model_path`: location of the model on disk, as a `Path`
    - `spacy_model`: the object returned by `spacy.load`
    - `metadata`: dict with the user-editable subset of the model meta data;
      edits are written back to the model by `save()` / `publish()`
    """

    # keys of the spacy model meta data that are copied into `metadata`
    _meta_keys_to_expose_to_user = [
        "name",
        "version",
        "description",
        "author",
        "email",
        "url",
        "license",
    ]

    def __init__(self, model_path: Optional[Union[str, Path]] = None):
        """Create a manager for the spacy model stored at `model_path`.

        Args:
            model_path: Location of the spacy model. The parameter keeps its
                `None` default for backwards compatibility, but a value is
                required.

        Raises:
            TypeError: If `model_path` is None.
        """
        self.load(model_path)

    def load(self, model_path: Union[str, Path]):
        """Load a spacy model from `model_path`.

        Raises:
            TypeError: If `model_path` is None.
        """
        if model_path is None:
            # fail with a message that names the argument instead of the
            # opaque TypeError raised by Path(None)
            raise TypeError(
                "model_path is required: pass the path of a spacy model as str or Path."
            )
        self.model_path = Path(model_path)
        self.spacy_model = spacy.load(model_path)
        # keys missing from the model meta data are exposed as empty strings
        self.metadata = {
            k: self.spacy_model.meta.get(k, "")
            for k in self._meta_keys_to_expose_to_user
        }

    def save(self):
        """Save any changes made to the model metadata.

        Raises:
            ValueError: If the `name` entry of `metadata` cannot be turned
                into a valid package name.
        """
        self._update_metadata()
        self.spacy_model.to_disk(self.model_path)

    def publish(self, hugging_face_token: Optional[str] = None) -> Dict[str, str]:
        """Publish the model to Hugging Face.

        This requires a User Access Token from https://huggingface.co/

        The token can either be passed via the `hugging_face_token` argument,
        or it can be set via the `HUGGING_FACE_TOKEN` environment variable.
        An empty token is treated like a missing one.

        Args:
            hugging_face_token (str, optional): Hugging Face User Access Token

        Returns:
            dict: URLs of the published model and the pip-installable wheel

        Raises:
            ValueError: If no token is available, or if the model name is not
                a valid package name.
        """
        # Resolve the token before touching the disk, so that a missing token
        # is reported without saving anything. An empty value (e.g. an unset
        # CI secret exported as "") counts as missing.
        token = hugging_face_token or os.environ.get("HUGGING_FACE_TOKEN")
        if not token:
            raise ValueError(
                "API TOKEN required: pass as string or set the HUGGING_FACE_TOKEN environment variable."
            )
        # packaging reads the model from `model_path`, so persist edits first
        self.save()
        huggingface_hub.login(token=token)
        with tempfile.TemporaryDirectory() as tmpdir:
            # convert model to a python package incl binary wheel
            output_path = Path(tmpdir)
            spacy.cli.package(self.model_path, output_path, create_wheel=True)
            # push the package to hugging face; this has to happen inside the
            # `with` block because the wheel is deleted when it exits
            return spacy_huggingface_hub.push(
                _construct_wheel_path(output_path, self.spacy_model.meta)
            )

    def _update_metadata(self):
        """Validate the name and copy `metadata` into the spacy model meta data."""
        # a missing / None name is handed over as "" so that it is rejected
        # with the documented ValueError instead of a TypeError from re.sub
        self.metadata["name"] = _make_valid_package_name(
            self.metadata.get("name") or ""
        )
        for k, v in self.metadata.items():
            # only keys already known to the model are written back
            if k in self.spacy_model.meta:
                self.spacy_model.meta[k] = v
49 changes: 31 additions & 18 deletions moralization/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,44 @@
import pytest
from moralization import input_data
from moralization.data_manager import DataManager
import pathlib


def _data_path_fixture(dir_path):
@pytest.fixture
def _fixture():
return dir_path
@pytest.fixture(scope="session")
def data_dir():
    """Resolved path of the ``data`` directory next to this file's directory."""
    package_root = pathlib.Path(__file__).parents[1].resolve()
    return package_root / "data"

return _fixture

@pytest.fixture(scope="session")
def ts_file(data_dir):
    """Path of ``TypeSystem.xml`` inside the test data directory."""
    type_system_name = "TypeSystem.xml"
    return data_dir / type_system_name

def _doc_dict_fixture(dir_path):
@pytest.fixture
def _fixture():
return input_data.InputOutput.read_data(dir_path)

return _fixture
@pytest.fixture(scope="session")
def data_file(data_dir):
    """Path of the trimmed ``.xmi`` test data file inside the test data directory."""
    xmi_name = "test_data-trimmed_version_of-Interviews-pos-SH-neu-optimiert-AW.xmi"
    return data_dir / xmi_name


dir_path = pathlib.Path(__file__).parents[1].resolve() / "data"
data_dir = _data_path_fixture(dir_path)
doc_dicts = _doc_dict_fixture(dir_path)
@pytest.fixture(scope="session")
def config_file(data_dir):
    """Path of ``config.cfg`` inside the test data directory."""
    config_name = "config.cfg"
    return data_dir / config_name


ts_file = _data_path_fixture(dir_path / "TypeSystem.xml")
data_file = _data_path_fixture(
dir_path / "test_data-trimmed_version_of-Interviews-pos-SH-neu-optimiert-AW.xmi"
)
config_file = _data_path_fixture(dir_path / "config.cfg")
@pytest.fixture(scope="session")
def model_path(data_dir, config_file, tmp_path_factory) -> pathlib.Path:
    """
    Train a model for a single epoch and yield the path of the result.

    The fixture is session scoped, so the training only runs once and all
    tests of the pytest session share the same temporary model directory.
    """
    data_manager = DataManager(data_dir)
    data_manager.export_data_DocBin()
    working_dir = tmp_path_factory.mktemp("model")
    data_manager.spacy_train(working_dir=working_dir, config=config_file, n_epochs=1)
    best_model_dir = working_dir / "output" / "model-best"
    yield best_model_dir


@pytest.fixture
def doc_dicts(data_dir):
    """Result of ``InputOutput.read_data`` for the test data directory."""
    data_dir_as_str = str(data_dir)
    return input_data.InputOutput.read_data(data_dir_as_str)
100 changes: 100 additions & 0 deletions moralization/tests/test_model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from moralization.model_manager import ModelManager
import spacy
import pytest
import spacy_huggingface_hub
import huggingface_hub
from typing import Any
from pathlib import Path


def test_model_manager_valid_path(model_path):
    """A manager created from a trained model exposes the loaded ``de`` pipeline."""
    manager = ModelManager(model_path)
    nlp = manager.spacy_model
    assert nlp is not None
    assert nlp.lang == "de"
    assert nlp.path == model_path


def test_model_manager_modify_metadata(model_path):
    """Metadata edits survive ``save()``, a re-load and a direct spacy load."""
    manager = ModelManager(model_path)
    # every editable field is simply set to its own key name
    expected = {
        key: f"{key}"
        for key in (
            "name",
            "version",
            "description",
            "author",
            "email",
            "url",
            "license",
        )
    }
    manager.metadata.update(expected)
    manager.save()
    for key, value in expected.items():
        assert manager.metadata[key] == value
    # a fresh load through the manager sees the saved values
    manager.load(model_path)
    for key, value in expected.items():
        assert manager.metadata[key] == value
    # and so does spacy when it loads the model directly
    reloaded = spacy.load(model_path)
    for key, value in expected.items():
        assert reloaded.meta[key] == value


def test_model_manager_modify_metadata_fixable_invalid_names(model_path):
    """Names that can be repaired are normalised when the model is saved."""
    manager = ModelManager(model_path)
    repairs = {"!hm & __OK?,...": "hm_ok", "Im - S": "im_s"}
    for invalid_name, valid_name in repairs.items():
        manager.metadata["name"] = invalid_name
        # assigning alone does not touch the name ...
        assert manager.metadata["name"] == invalid_name
        # ... it is only made valid by save()
        manager.save()
        assert manager.metadata["name"] == valid_name
        reloaded = spacy.load(model_path)
        assert reloaded.meta["name"] == valid_name


def test_model_manager_modify_metadata_unfixable_invalid_names(model_path):
    """Names that cannot be repaired make ``save()`` raise a ``ValueError``."""
    manager = ModelManager(model_path)
    for bad_name in ("", "_", "ü"):
        manager.metadata["name"] = bad_name
        with pytest.raises(ValueError) as excinfo:
            manager.save()
        message = str(excinfo.value).lower()
        assert "invalid" in message


def test_model_manager_publish_no_token(model_path, monkeypatch):
    """Without argument and environment variable, ``publish()`` asks for a token."""
    # make sure a token from the surrounding environment cannot leak in
    monkeypatch.delenv("HUGGING_FACE_TOKEN", raising=False)
    manager = ModelManager(model_path)
    with pytest.raises(ValueError) as excinfo:
        manager.publish()
    message = str(excinfo.value).lower()
    assert "token" in message


def test_model_manager_publish_invalid_token_env(model_path, monkeypatch):
    """An invalid token taken from the environment is rejected with a ``ValueError``."""
    monkeypatch.setenv("HUGGING_FACE_TOKEN", "invalid")
    manager = ModelManager(model_path)
    with pytest.raises(ValueError) as excinfo:
        manager.publish()
    message = str(excinfo.value).lower()
    assert "token" in message


def test_model_manager_publish_invalid_token_arg(model_path):
    """An invalid token passed as argument is rejected with a ``ValueError``."""
    manager = ModelManager(model_path)
    with pytest.raises(ValueError) as excinfo:
        manager.publish(hugging_face_token="invalid")
    message = str(excinfo.value).lower()
    assert "token" in message


def test_model_manager_publish_mock_push(model_path: Path, monkeypatch, tmp_path):
    """``publish()`` hands a wheel named after the metadata to ``push()``."""

    def fake_push(whl_path: Path):
        # stand-in for spacy_huggingface_hub.push(): instead of uploading,
        # move the supplied wheel to a temporary path owned by this test
        destination = tmp_path / whl_path.name
        whl_path.rename(destination)
        return {}

    def fake_login(*args: Any, **kwargs: Any) -> None:
        # stand-in for huggingface_hub.login(): accept anything, do nothing
        return

    monkeypatch.setattr(spacy_huggingface_hub, "push", fake_push)
    monkeypatch.setattr(huggingface_hub, "login", fake_login)

    manager = ModelManager(model_path)
    # name and version determine the file name of the compiled wheel
    manager.metadata.update({"name": "my_new_pipeline", "version": "1.2.3"})
    manager.publish(hugging_face_token="abc123")
    expected_wheel = tmp_path / "de_my_new_pipeline-1.2.3-py3-none-any.whl"
    assert expected_wheel.is_file()
Loading

0 comments on commit e5341be

Please sign in to comment.