forked from MLlab4CS/Astro-mT5
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
24 lines (19 loc) · 877 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm
import json
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
# Column layout of the data files: column 1 holds the token text and
# column 4 holds the gold NER tag.
COLUMNS = {1: 'text', 4: 'ner'}

DATA_FOLDER = '/content/Dataset_tr_ts_valid/mt5_valid'
# NOTE(review): the space in "best_ model_large.pt" looks like a typo (flair's
# ModelTrainer saves "best-model.pt" by default). It is kept verbatim so an
# existing file with exactly this name still loads — confirm against the disk.
MODEL_PATH = '/content/mt5-large/best_ model_large.pt'
PREDICTIONS_PATH = '/content/mt5-large/pred_valid.txt'


def main(
    data_folder: str = DATA_FOLDER,
    model_path: str = MODEL_PATH,
    out_path: str = PREDICTIONS_PATH,
    mini_batch_size: int = 4,
):
    """Evaluate a locally trained SequenceTagger on the validation file.

    Args:
        data_folder: Directory containing train.txt, test.txt and valid.txt.
        model_path: Path to the trained model checkpoint on local disk.
        out_path: File the evaluation predictions are written to.
        mini_batch_size: Batch size used during evaluation.

    Returns:
        The object returned by ``SequenceTagger.evaluate``; it is also printed.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing file.
    """
    from pathlib import Path

    # Check the checkpoint before the (slow) corpus load. A missing local path
    # handed to SequenceTagger.load may be treated as a model name to resolve
    # remotely, which fails with a far less direct error.
    if not Path(model_path).is_file():
        raise FileNotFoundError(f"Trained model not found: {model_path!r}")

    # valid.txt is deliberately registered as the corpus *test* split (and
    # test.txt as dev) so that ``corpus.test`` below evaluates on valid.txt.
    corpus: Corpus = ColumnCorpus(
        data_folder,
        COLUMNS,
        train_file='train.txt',
        dev_file='test.txt',
        test_file='valid.txt',
    )

    tagger = SequenceTagger.load(model_path)
    result = tagger.evaluate(
        corpus.test,
        gold_label_type='ner',
        mini_batch_size=mini_batch_size,
        out_path=out_path,
    )
    print(result)
    return result


if __name__ == "__main__":
    main()