-
Notifications
You must be signed in to change notification settings - Fork 0
/
Q&A_CDQA_Finetuning.py
57 lines (39 loc) · 1.67 KB
/
Q&A_CDQA_Finetuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
# Install the essential libraries.
# NOTE: get_ipython() only exists inside an IPython/Jupyter session, so this
# exported notebook has to be run with IPython rather than the plain interpreter.
get_ipython().system('pip install cdqa')
# Pandas has to be reinstalled afterwards, since cdqa requires a different
# version of pandas compared to our requirement.
# The original line was a bare "conda update ..." (presumably an IPython
# automagic in the notebook), which is a SyntaxError in an exported .py file;
# the %conda line magic is therefore invoked explicitly.
get_ipython().run_line_magic('conda', 'update --force-reinstall pandas')
import pandas as pd
import joblib
from ast import literal_eval
from cdqa.utils.filters import filter_paragraphs
from cdqa.utils.download import download_model, download_bnpp_data
from cdqa.pipeline.cdqa_sklearn import QAPipeline
# Fetch the 'bert-squad_1.1' model required by cdqa from the releases into ./models
download_model(model='bert-squad_1.1', dir='./models')

# Dataset with the title and text of covid-19 related articles retrieved from CORD 19.
# The 'paragraphs' column is parsed back into Python objects with literal_eval on load,
# and the resulting frame is passed through filter_paragraphs.
corpus_csv = '/kaggle/input/dataset.csv'
df = filter_paragraphs(
    pd.read_csv(corpus_csv, converters={'paragraphs': literal_eval})
)

# Build the QA pipeline with the downloaded model acting as the reader
reader_path = 'models/bert_qa.joblib'
cdqa_pipeline = QAPipeline(reader=reader_path)

# Fit the retriever to the dataset.
# NOTE(review): only fit_retriever is called; no reader fine-tuning call appears here.
cdqa_pipeline.fit_retriever(df)

# Sanity-check the pipeline with a sample question
query = 'How contagious is coronavirus?'
prediction = cdqa_pipeline.predict(query)

# Print the query, then the first three fields of the prediction in order
print('query: {}\n'.format(query))
for label, position in (('answer', 0), ('title', 1), ('paragraph', 2)):
    print('{}: {}\n'.format(label, prediction[position]))

# Save the pipeline's reader so it can be reused later
finetuned_reader_path = 'models/bert_qa_finetune.joblib'
cdqa_pipeline.dump_reader(finetuned_reader_path)
# To use the saved model again:
# The original line was a bare "pip install joblib" (presumably an IPython
# automagic in the notebook), which is a SyntaxError in an exported .py file;
# the %pip line magic is therefore invoked explicitly.
get_ipython().run_line_magic('pip', 'install joblib')
# Load the model from the same path that dump_reader wrote to above (the
# original loaded 'bert_qa_finetune2.joblib', which this script never creates).
# dump_reader saves the reader, so it is wrapped in a QAPipeline again - the
# same way the downloaded reader was - and the retriever is re-fitted on df
# before a plain-text query is passed to predict.
finetune_x = QAPipeline(reader='models/bert_qa_finetune.joblib')
finetune_x.fit_retriever(df)
# Take a query and predict using the reloaded model (change variable names as needed)
query = 'What is coronavirus?'
prediction = finetune_x.predict(query)