-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtagger.py
30 lines (21 loc) · 905 Bytes
/
tagger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import logging
from flair.data import Sentence
from flair.models import SequenceTagger
logger = logging.getLogger(__name__)


class Tagger:
    """Callable wrapper around a flair ``SequenceTagger`` loaded once from disk.

    The model is loaded in ``__init__``; each call tags one piece of text and
    returns the predicted spans as text, one span per line.
    """

    def __init__(self, model_path: str):
        """Load the flair sequence tagger for ``model_path``.

        Args:
            model_path: Location of the model, passed verbatim to
                ``SequenceTagger.load``.
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info('Initializing flair tagger from path: %s', model_path)
        self._sequence_tagger = SequenceTagger.load(model_path)
        logger.info('Successfully initialized tagger from path: %s', model_path)

    def __call__(self, text: str, *, label_type: str = 'ner') -> str:
        """Tag ``text`` and return its predicted spans, newline-separated.

        Args:
            text: Raw text to wrap in a flair ``Sentence`` and tag.
            label_type: Label type whose spans are collected from the sentence
                after prediction. Defaults to ``'ner'`` (the original,
                hard-coded behaviour).

        Returns:
            The ``str()`` of each span of ``label_type``, joined with newlines;
            an empty string when no spans are found.
        """
        sentence = Sentence(text)
        # predict() annotates ``sentence`` itself; its return value is unused.
        self._sequence_tagger.predict(sentence)
        return '\n'.join(str(span) for span in sentence.get_spans(label_type))
if __name__ == '__main__':
    # Demo: run the pretrained 'flair/ner-english' model on one sentence,
    # print every recognised span, then save the model to a local file
    # (presumably so it can later be loaded via Tagger(model_path) — confirm).
    ner_model = SequenceTagger.load('flair/ner-english')
    demo_sentence = Sentence('George Washington went to Washington')
    ner_model.predict(demo_sentence)

    found_spans = demo_sentence.get_spans('ner')
    for span in found_spans:
        print(span)

    # NOTE(review): '/tmp' is a POSIX-specific location; confirm whether a
    # portable temp directory is wanted before other code relies on this path.
    ner_model.save('/tmp/my_ner_tagger.pt')