Skip to content

Commit

Permalink
Improve code quality (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
iulusoy authored Oct 4, 2023
1 parent 69a8acf commit dc30dbc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
8 changes: 4 additions & 4 deletions moralization/tests/test_transformers_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,9 @@ def test_load_dataloader(gen_instance, train_test_dataset):
def test_load_optimizer(gen_instance):
gen_instance.label_names = ["A", "B", "C"]
gen_instance._load_optimizer(learning_rate=1e-3)
- assert gen_instance.optimizer.defaults["lr"] == 1e-3
+ assert gen_instance.optimizer.defaults["lr"] == pytest.approx(1e-3, 1e-4)
gen_instance._load_optimizer(learning_rate=1e-3, kwargs={"weight_decay": 0.015})
- assert gen_instance.optimizer.defaults["weight_decay"] == 0.015
+ assert gen_instance.optimizer.defaults["weight_decay"] == pytest.approx(0.015, 1e-3)


def test_load_scheduler(gen_instance, train_test_dataset):
Expand Down Expand Up @@ -263,15 +263,15 @@ def test_train_evaluate(gen_instance, gen_instance_dm):
num_train_epochs,
learning_rate,
)
- assert gen_instance.results["overall_precision"] == 0.0
+ assert gen_instance.results["overall_precision"] == pytest.approx(0.0, 1e-3)
assert (model_path / "pytorch_model.bin").is_file()
assert (model_path / "special_tokens_map.json").is_file()
assert (model_path / "config.json").is_file()
evaluate_result = gen_instance.evaluate("Python ist toll.")
assert evaluate_result[0]["score"]
del gen_instance._model_path
with pytest.raises(ValueError):
- evaluate_result = gen_instance.evaluate("Python ist toll.")
+ gen_instance.evaluate("Python ist toll.")
# check that column names throw error if not given correctly
label_column_name = "something"
with pytest.raises(ValueError):
Expand Down
11 changes: 7 additions & 4 deletions moralization/transformers_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self,
model_path: Union[str, Path],
model_name: str = "bert-base-cased",
- label_names: List = ["0", "M", "M-BEG"],
+ label_names: List = None,
) -> None:
"""
Import an existing model from `model_name` from Hugging Face.
Expand All @@ -81,6 +81,8 @@ def __init__(
"""
super().__init__(model_path)
self.model_name = model_name
+ if label_names is None:
+ label_names = ["0", "M", "M-BEG"]
self._model_is_trained = False
self.metadata = _import_or_create_metadata(self.model_path)
# somewhere we should check that the label names length is same as number of different labels
Expand Down Expand Up @@ -471,8 +473,8 @@ def _initialize_training(
def train(
self,
data_manager: DataManager,
- token_column_name: str,
- label_column_name: str,
+ token_column_name: str = "Sentences",
+ label_column_name: str = "Labels",
num_train_epochs: int = 5,
learning_rate: float = 2e-5,
) -> None:
Expand Down Expand Up @@ -569,7 +571,8 @@ def evaluate(self, token: str):
)
return token_classifier(token)

- def test(self):
+ def test(self, test_string: str, style: str = "span"):
+ # to be completed
pass

def _check_model_is_trained_before_it_can_be(self, action: str = "used"):
Expand Down

0 comments on commit dc30dbc

Please sign in to comment.