Skip to content

Commit

Permalink
transformers changed model save format (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
iulusoy authored May 15, 2024
1 parent 8f167a3 commit 4d8348d
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
extend-ignore = E203
extend-ignore = E122, E203, E231, W604
exclude = .git,__pycache__,.ipynb_checkpoints
max-line-length = 120
2 changes: 1 addition & 1 deletion .flake8_nb
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8_nb]
notebook_cell_format = '{nb_path}:code_cell#{code_cell_count}'
extend-ignore = E203, E402
extend-ignore = E122, E203, E231, E402, W604, E713
exclude = .git,__pycache__,.ipynb_checkpoints
max-line-length = 120
max-complexity = 18
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
repos:
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
rev: 0.7.1
hooks:
- id: nbstripout
files: ".ipynb"
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 24.4.2
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
- repo: https://github.com/s-weigand/flake8-nb
Expand Down
1 change: 1 addition & 0 deletions moralization/analyse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Contains statistical analysis.
"""

from collections import defaultdict
import pandas as pd
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions moralization/input_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module that handles input reading.
"""

from cassis import load_typesystem, load_cas_from_xmi, typesystem, Cas
import pathlib
import importlib_resources
Expand Down
3 changes: 2 additions & 1 deletion moralization/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Contains plotting functionality.
"""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -72,7 +73,7 @@ def _get_filter_multiindex(occurrence_df: pd.DataFrame, filters):
filter_dict["sub"].append(_filter)

else:
raise KeyError(f"Filter key: `{ _filter}` not in dataframe columns.")
raise KeyError(f"Filter key: `{_filter}` not in dataframe columns.")

if filter_dict["main"] == []:
filter_dict["main"] = slice(None)
Expand Down
8 changes: 4 additions & 4 deletions moralization/tests/test_transformers_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,11 @@ def test_train_evaluate(gen_instance, gen_instance_dm):
learning_rate,
)
assert gen_instance.results["overall_precision"] == pytest.approx(0.0, 1e-3)
assert (model_path / "pytorch_model.bin").is_file()
assert (model_path / "model.safetensors").is_file()
assert (model_path / "special_tokens_map.json").is_file()
assert (model_path / "config.json").is_file()
evaluate_result = gen_instance.evaluate("Python ist toll.")
assert evaluate_result[0]["score"]
del gen_instance._model_path
with pytest.raises(ValueError):
gen_instance.evaluate("Python ist toll.")
# check that column names throw error if not given correctly
label_column_name = "something"
with pytest.raises(ValueError):
Expand All @@ -291,6 +288,9 @@ def test_train_evaluate(gen_instance, gen_instance_dm):
num_train_epochs,
learning_rate,
)
del gen_instance._model_path
with pytest.raises(ValueError):
gen_instance.evaluate("Python ist toll.")


def test_publish(gen_instance):
Expand Down
2 changes: 1 addition & 1 deletion moralization/transformers_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def _evaluate_model(self):
def save(self):
"""Save the model to the set model path.
If a model already exists in that path, it will be overwritten."""
model_file = self.model_path / "pytorch_model.bin"
model_file = self.model_path / "model.safetensors"
if model_file.exists():
print(
"Model file already existing at specified model path {} - will be overwritten.".format(
Expand Down
60 changes: 30 additions & 30 deletions notebooks/DemoNotebook_spacy_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "17ddce3f",
"id": "0",
"metadata": {
"id": "-tIt14wg_KRi"
},
Expand All @@ -17,7 +17,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "dec6ba0b",
"id": "1",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -42,7 +42,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f7ce3ad0",
"id": "2",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -51,7 +51,7 @@
},
{
"cell_type": "markdown",
"id": "a0aa6775",
"id": "3",
"metadata": {},
"source": [
"### Import training data using DataManager\n",
Expand All @@ -64,7 +64,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fa65a7f6",
"id": "4",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -81,7 +81,7 @@
},
{
"cell_type": "markdown",
"id": "d0ebdd8c",
"id": "5",
"metadata": {},
"source": [
"You can provide the language in the model initialization, as well as the task that should be trained on. Default language is German, and the default task is task1: Training for labels in category 1. The available tasks are:\n",
Expand All @@ -97,7 +97,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5508ae9c",
"id": "6",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -118,7 +118,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "44031c8d",
"id": "7",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -128,7 +128,7 @@
},
{
"cell_type": "markdown",
"id": "a988aa6b",
"id": "8",
"metadata": {},
"source": [
"### Create a new spacy model using ModelManager\n",
Expand All @@ -153,7 +153,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b868d3d7",
"id": "9",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -167,22 +167,22 @@
},
{
"cell_type": "markdown",
"id": "c7d2c783",
"id": "10",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "a16b05d5",
"id": "11",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ec9f981",
"id": "12",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -191,7 +191,7 @@
},
{
"cell_type": "markdown",
"id": "e6cb68cb",
"id": "13",
"metadata": {},
"source": [
"### Edit metadata\n",
Expand All @@ -204,7 +204,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3769b760",
"id": "14",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -214,7 +214,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9701fcdb",
"id": "15",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -230,7 +230,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5f5172b5",
"id": "16",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -240,7 +240,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "961d1d2f",
"id": "17",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -249,7 +249,7 @@
},
{
"cell_type": "markdown",
"id": "a0059ffd",
"id": "18",
"metadata": {},
"source": [
"### Train the model\n",
Expand All @@ -261,7 +261,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "576bdb58",
"id": "19",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -270,7 +270,7 @@
},
{
"cell_type": "markdown",
"id": "a105a00a",
"id": "20",
"metadata": {
"tags": []
},
Expand All @@ -284,7 +284,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6a10dd79",
"id": "21",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -293,7 +293,7 @@
},
{
"cell_type": "markdown",
"id": "42d60f85",
"id": "22",
"metadata": {},
"source": [
"### Test the model"
Expand All @@ -302,7 +302,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e6ed2d40",
"id": "23",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -311,7 +311,7 @@
},
{
"cell_type": "markdown",
"id": "f1c99e30",
"id": "24",
"metadata": {},
"source": [
"### Publish model to hugging-face\n",
Expand All @@ -324,7 +324,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "18b270dd",
"id": "25",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -334,7 +334,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "2f36af3f",
"id": "26",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -343,7 +343,7 @@
},
{
"cell_type": "markdown",
"id": "c0b7c1f0",
"id": "27",
"metadata": {},
"source": [
"### Load an existing model\n",
Expand All @@ -354,7 +354,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d9706336",
"id": "28",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -365,7 +365,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "70f521ea",
"id": "29",
"metadata": {},
"outputs": [],
"source": []
Expand Down
Loading

0 comments on commit 4d8348d

Please sign in to comment.