diff --git a/README.md b/README.md index d153e8e..797989a 100755 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ bunruija-evaluate -y config.yaml ### data You can set data-related settings in `data`. -```sh +```yaml data: train: train.csv # training data dev: dev.csv # development data @@ -108,7 +108,7 @@ When you set `label_column` to `label` and `text_column` to `text`, which are th Format of `csv`: -``` +```csv label,text label_name,sentence … @@ -116,13 +116,13 @@ label_name,sentence Format of `json`: -``` +```json [{"label", "label_name", "text": "sentence"}] ``` Format of `jsonl`: -``` +```json {"label", "label_name", "text": "sentence"} ``` @@ -131,12 +131,15 @@ You can set pipeline of your model in `pipeline` ## Prediction using the trained classifier in Python code +After you have trained a classification model, you can use it for prediction as follows: ```python from bunruija import Predictor -predictor = Predictor(args.yaml) +predictor = Predictor.from_pretrained("output_dir") while True: text = input("Input:") label: list[str] = predictor([text], return_label_type="str") print(label[0]) ``` + +`output_dir` is the directory specified by the `output_dir` setting in your config. 
diff --git a/bunruija/evaluator.py b/bunruija/evaluator.py index e78a1a0..fde21cd 100755 --- a/bunruija/evaluator.py +++ b/bunruija/evaluator.py @@ -14,7 +14,7 @@ class Evaluator: def __init__(self, args: Namespace): self.config = BunruijaConfig.from_yaml(args.yaml) self.verbose = args.verbose - self.predictor = Predictor(args.yaml) + self.predictor = Predictor.from_pretrained(self.config.output_dir) def evaluate(self): labels_test, X_test = load_data( diff --git a/bunruija/options.py b/bunruija/options.py index bf6da32..e906fc2 100755 --- a/bunruija/options.py +++ b/bunruija/options.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path def get_parser(): @@ -18,7 +19,8 @@ def get_default_train_parser(): def get_default_prediction_parser(): - parser = get_parser() + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=Path, required=True) return parser diff --git a/bunruija/predict.py b/bunruija/predict.py index b633e5c..82cfab3 100755 --- a/bunruija/predict.py +++ b/bunruija/predict.py @@ -7,7 +7,7 @@ def main(args): parser = options.get_default_prediction_parser() args = parser.parse_args(args) - predictor = Predictor(args.yaml) + predictor = Predictor.from_pretrained(args.model) while True: text = input("Input:") label: list[str] = predictor([text], return_label_type="str") @@ -16,3 +16,7 @@ def main(args): def cli_main(): main(sys.argv[1:]) + + +if __name__ == "__main__": + cli_main() diff --git a/bunruija/predictor.py b/bunruija/predictor.py index 2419aec..ca8869a 100755 --- a/bunruija/predictor.py +++ b/bunruija/predictor.py @@ -4,21 +4,23 @@ import numpy as np from sklearn.preprocessing import LabelEncoder # type: ignore -from . 
import BunruijaConfig - class Predictor: """Predicts labels""" - def __init__(self, config_file): - config = BunruijaConfig.from_yaml(config_file) - model_path: Path = config.output_dir / "model.bunruija" + def __init__(self, model_dir: str | Path): + if isinstance(model_dir, str): + model_dir = Path(model_dir) + model_path: Path = model_dir / "model.bunruija" with open(model_path, "rb") as f: model_data: dict = pickle.load(f) + self.model = model_data["pipeline"] + self.label_encoder: LabelEncoder = model_data["label_encoder"] - self.model = model_data["pipeline"] - self.label_encoder: LabelEncoder = model_data["label_encoder"] + @classmethod + def from_pretrained(cls, model_path: str | Path): + return cls(model_path) def __call__( self,