Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Upgrade FAQ finance to Milvus 2.1 #3267

Merged
merged 4 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 194 additions & 24 deletions applications/question_answering/faq_finance/README.md

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions applications/question_answering/faq_finance/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from milvus import MetricType, IndexType
search_param = {'nprobe': 20}
collection_name = 'faq_finance'
partition_tag = 'partition_1'

MILVUS_HOST = '10.21.226.173'
MILVUS_HOST = '10.21.226.175'
MILVUS_PORT = 8530
data_dim = 256
top_k = 10
embedding_name = 'embeddings'

collection_param = {
'dimension': 256,
'index_file_size': 256,
'metric_type': MetricType.L2
index_config = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {
"nlist": 1000
},
}

index_type = IndexType.IVF_FLAT
index_param = {'nlist': 1000}

top_k = 10
search_param = {'nprobe': 20}
search_params = {
"metric_type": "L2",
"params": {
"nprobe": top_k
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class ErnieOp(Op):

def init_op(self):
from paddlenlp.transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained('ernie-3.0-medium-zh')
self.tokenizer = AutoTokenizer.from_pretrained(
'rocketqa-zh-base-query-encoder')

def preprocess(self, input_dicts, data_id, log_id):
from paddlenlp.data import Stack, Tuple, Pad
Expand Down
13 changes: 8 additions & 5 deletions applications/question_answering/faq_finance/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,25 @@
default='./checkpoint/model_50/model_state.pdparams', help="The path to model parameters to be loaded.")
parser.add_argument("--output_path", type=str, default='./output',
help="The path of model parameter in static graph to be saved.")
parser.add_argument('--model_name_or_path', default="rocketqa-zh-base-query-encoder", help="The pretrained model used for training")
parser.add_argument("--output_emb_size", default=256, type=int, help="Output_embedding_size, 0 means use hidden_size as output embedding size.")
args = parser.parse_args()
# yapf: enable

if __name__ == "__main__":
output_emb_size = 256

pretrained_model = AutoModel.from_pretrained("ernie-3.0-medium-zh")
pretrained_model = AutoModel.from_pretrained(args.model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained('ernie-3.0-medium-zh')
model = SimCSE(pretrained_model, output_emb_size=output_emb_size)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = SimCSE(pretrained_model, output_emb_size=args.output_emb_size)

if args.params_path and os.path.isfile(args.params_path):
state_dict = paddle.load(args.params_path)
model.set_dict(state_dict)
print("Loaded parameters from %s" % args.params_path)

else:
raise ValueError(
"Please set --params_path with correct pretrained model file")
model.eval()
# Convert to static graph with specific input description
model = paddle.jit.to_static(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

parser.add_argument("--corpus_file", type=str, required=True,
help="The corpus_file path.")

parser.add_argument('--model_name_or_path', default="rocketqa-zh-base-query-encoder", help="The pretrained model used for training")
parser.add_argument("--max_seq_length", default=64, type=int,
help="The maximum total input sequence length after tokenization. Sequences "
"longer than this will be truncated, sequences shorter will be padded.")
Expand Down Expand Up @@ -214,7 +214,7 @@ def read_text(file_path):
args.batch_size, args.use_tensorrt, args.precision,
args.cpu_threads, args.enable_mkldnn)

tokenizer = AutoTokenizer.from_pretrained('ernie-3.0-medium-zh')
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
id2corpus = read_text(args.corpus_file)

corpus_list = [{idx: text} for idx, text in id2corpus.items()]
Expand Down
117 changes: 117 additions & 0 deletions applications/question_answering/faq_finance/milvus_ann_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from tqdm import tqdm
import time
import argparse

import numpy as np
from milvus_util import VecToMilvus, RecallByMilvus, text_max_len
from config import collection_name, partition_tag, embedding_name

# Command-line options: this script can bulk-insert corpus embeddings into
# Milvus (--insert) and/or run a recall (ANN search) for one stored vector
# (--search) in a single invocation.
parser = argparse.ArgumentParser()
# NOTE(review): `default` is redundant when required=True — argparse never
# falls back to the default for a required argument. Consider dropping it.
parser.add_argument("--data_path",
                    default='data/corpus.csv',
                    type=str,
                    required=True,
                    help="The data for vector extraction.")
# Path of the pre-computed corpus embeddings (.npy), one row per corpus line.
parser.add_argument("--embedding_path",
                    default='corpus_embedding.npy',
                    type=str,
                    required=True,
                    help="The vector path for data.")
# Row index of the embedding used as the query vector in --search mode.
parser.add_argument('--index',
                    default=0,
                    type=int,
                    help='index of the vector for search')
parser.add_argument('--insert',
                    action='store_true',
                    help='whether to insert data')
parser.add_argument('--search',
                    action='store_true',
                    help='whether to search data')
# Number of entities sent to Milvus per insert call; tune for memory vs speed.
parser.add_argument('--batch_size',
                    default=100000,
                    type=int,
                    help='number of examples to insert each time')
args = parser.parse_args()


def read_text(file_path):
    """Read a tab-separated corpus file into a list of QA dicts.

    Each line of the file is expected to be "question<TAB>answer".

    Args:
        file_path: Path to the corpus file.

    Returns:
        A list of dicts with 'question' and 'answer' keys, in file order.

    Raises:
        ValueError: If a line does not contain exactly one tab separator.
    """
    id2corpus = []
    # Use a context manager so the handle is closed even on error (the
    # original called open() and never closed the file), and decode UTF-8
    # explicitly instead of relying on the platform default encoding.
    with open(file_path, encoding='utf-8') as f:
        for line in f:
            question, answer = line.strip().split('\t')
            id2corpus.append({'question': question, 'answer': answer})
    return id2corpus


def milvus_data_insert(data_path, embedding_path, batch_size):
    """Rebuild the Milvus collection from a corpus file and its embeddings.

    Drops any existing collection of the same name, then inserts entities
    (id, question, answer, embedding) in batches of ``batch_size``.

    Args:
        data_path: Path to the tab-separated question/answer corpus file.
        embedding_path: Path to the .npy file of corpus embeddings.
        batch_size: Number of entities sent to Milvus per insert call.
    """
    corpus = read_text(data_path)
    vectors = np.load(embedding_path)
    total = vectors.shape[0]

    client = VecToMilvus()
    # Start from a clean slate so re-running the script never duplicates rows.
    client.drop_collection(collection_name)

    for start in tqdm(range(0, total, batch_size)):
        end = min(start + batch_size, total)
        batch_ids = list(range(start, end))
        # Texts are truncated to fit the collection's VARCHAR field length.
        questions = [corpus[j]['question'][:text_max_len - 1] for j in batch_ids]
        answers = [corpus[j]['answer'][:text_max_len - 1] for j in batch_ids]
        # Field order must match the collection schema: pk, question, answer,
        # then the embedding field (numpy.ndarray or list both accepted).
        entities = [
            batch_ids,
            questions,
            answers,
            vectors[np.arange(start, end)],
        ]
        client.insert(collection_name=collection_name,
                      entities=entities,
                      index_name=embedding_name,
                      partition_tag=partition_tag)


def milvus_data_recall(embedding_path, index):
    """Search the Milvus collection with one stored embedding and print hits.

    Args:
        embedding_path: Path to the .npy file of corpus embeddings.
        index: Row index of the embedding to use as the query vector.
    """
    embeddings = np.load(embedding_path)
    num_embeddings = embeddings.shape[0]
    recall_client = RecallByMilvus()
    # Valid row indices are 0 .. num_embeddings - 1. The original check used
    # `>`, which let index == num_embeddings through and crashed with an
    # IndexError below; it also misspelled "size" in the message.
    if index >= num_embeddings:
        print('Index should not be larger than embedding size')
        return
    # Keep a 2-D (1, dim) query matrix, as Milvus expects a batch of vectors.
    query = embeddings[np.arange(index, index + 1)]
    time_start = time.time()
    # NOTE(review): entities are inserted with question/answer fields in
    # milvus_data_insert — confirm the collection schema actually exposes a
    # 'text' field, otherwise output_fields / entity.get('text') won't match.
    result = recall_client.search(query,
                                  embedding_name,
                                  collection_name,
                                  partition_names=[partition_tag],
                                  output_fields=['pk', 'text'])
    time_end = time.time()
    print('time cost', time_end - time_start, 's')
    for hits in result:
        for hit in hits:
            print(f"hit: {hit}, text field: {hit.entity.get('text')}")


if __name__ == "__main__":
    # Insert and search are independent flags and may be combined in one run.
    if args.insert:
        milvus_data_insert(args.data_path, args.embedding_path, args.batch_size)
    if args.search:
        milvus_data_recall(args.embedding_path, args.index)
Loading