Skip to content

Commit

Permalink
Fix Predictor interface
Browse files Browse the repository at this point in the history
  • Loading branch information
tma15 committed Feb 4, 2024
1 parent 4fed6da commit c7e9d01
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 15 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ bunruija-evaluate -y config.yaml
### data
You can set data-related settings in `data`.

```sh
```yaml
data:
train: train.csv # training data
dev: dev.csv # development data
Expand All @@ -108,21 +108,21 @@ When you set `label_column` to `label` and `text_column` to `text`, which are th

Format of `csv`:

```
```csv
label,text
label_name,sentence
```

Format of `json`:

```
```json
[{"label", "label_name", "text": "sentence"}]
```

Format of `jsonl`:

```
```json
{"label", "label_name", "text": "sentence"}
```

Expand All @@ -131,12 +131,15 @@ You can set pipeline of your model in `pipeline`


## Prediction using the trained classifier in Python code
After you trained a classification model, you can use that model for prediction as follows:
```python
from bunruija import Predictor
predictor = Predictor(args.yaml)
predictor = Predictor.from_pretrained("output_dir")
while True:
text = input("Input:")
label: list[str] = predictor([text], return_label_type="str")
print(label[0])
```

`output_dir` is a directory that is specified in `output_dir` in config.
2 changes: 1 addition & 1 deletion bunruija/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Evaluator:
def __init__(self, args: Namespace):
self.config = BunruijaConfig.from_yaml(args.yaml)
self.verbose = args.verbose
self.predictor = Predictor(args.yaml)
self.predictor = Predictor.from_pretrained(self.config.output_dir)

def evaluate(self):
labels_test, X_test = load_data(
Expand Down
4 changes: 3 additions & 1 deletion bunruija/options.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from pathlib import Path


def get_parser():
Expand All @@ -18,7 +19,8 @@ def get_default_train_parser():


def get_default_prediction_parser():
parser = get_parser()
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=Path, required=True)
return parser


Expand Down
6 changes: 5 additions & 1 deletion bunruija/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def main(args):
parser = options.get_default_prediction_parser()
args = parser.parse_args(args)

predictor = Predictor(args.yaml)
predictor = Predictor.from_pretrained(args.model)
while True:
text = input("Input:")
label: list[str] = predictor([text], return_label_type="str")
Expand All @@ -16,3 +16,7 @@ def main(args):

def cli_main():
main(sys.argv[1:])


if __name__ == "__main__":
cli_main()
16 changes: 9 additions & 7 deletions bunruija/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,23 @@
import numpy as np
from sklearn.preprocessing import LabelEncoder # type: ignore

from . import BunruijaConfig


class Predictor:
"""Predicts labels"""

def __init__(self, config_file):
config = BunruijaConfig.from_yaml(config_file)
model_path: Path = config.output_dir / "model.bunruija"
def __init__(self, model_dir: str | Path):
if isinstance(model_dir, str):
model_dir = Path(model_dir)

model_path: Path = model_dir / "model.bunruija"
with open(model_path, "rb") as f:
model_data: dict = pickle.load(f)
self.model = model_data["pipeline"]
self.label_encoder: LabelEncoder = model_data["label_encoder"]

self.model = model_data["pipeline"]
self.label_encoder: LabelEncoder = model_data["label_encoder"]
@classmethod
def from_pretrained(cls, model_path: str | Path):
return cls(model_path)

def __call__(
self,
Expand Down

0 comments on commit c7e9d01

Please sign in to comment.