-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from nbcstevenchen/main
Create Semantic Ranking for Title, Summary, Keywords, and Conference Name Searching
- Loading branch information
Showing
6 changed files
with
42,893 additions
and
0 deletions.
There are no files selected for viewing
42,769 changes: 42,769 additions & 0 deletions
42,769
cncf-youtube-channel-summarizer/data/cncf_video_summary_combine.csv
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from sentence_transformers import SentenceTransformer | ||
import pandas as pd | ||
import pickle | ||
def embedding_generator(model, data):
    """Encode *data* into dense sentence embeddings with a bi-encoder.

    Parameters
    ----------
    model : str
        Name of a SentenceTransformer model
        (e.g. 'multi-qa-MiniLM-L6-cos-v1').
    data : iterable of str
        Texts to embed (one embedding per text).

    Returns
    -------
    torch.Tensor
        Embedding matrix, row i corresponds to data[i].
    """
    # Fix: dropped the unused `pd.read_csv(...)` call that re-read the whole
    # corpus on every invocation — the caller already loads the CSV and
    # passes the texts in via *data*.
    bi_encoder = SentenceTransformer(model)
    bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
    ##### Semantic Search #####
    # Encode the passages with the bi-encoder for later semantic search.
    embeddings = bi_encoder.encode(data, convert_to_tensor=True, show_progress_bar=True)
    return embeddings
|
||
|
||
if __name__ == "__main__":
    # Build embeddings for the merged corpus and persist them for reuse
    # by the search scripts.
    corpus = pd.read_csv('data/cncf_video_summary_combine.csv')
    corpus_embeddings = embedding_generator('multi-qa-MiniLM-L6-cos-v1', corpus['merge'])
    with open('data/embedding.pkl', 'wb') as out_file:
        pickle.dump(corpus_embeddings.numpy(), out_file)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import pandas as pd | ||
import os | ||
|
||
def merge_csv():
    """Merge every per-conference CSV under data/ into one combined file.

    Reads each ``data/*.csv``, concatenates them, adds a ``merge`` column
    (video title + conference name + summary + keywords) used as the
    search corpus, and writes ``data/cncf_video_summary_combine.csv``.
    """
    output_name = 'cncf_video_summary_combine.csv'
    csv_list = [
        'data/' + name
        for name in os.listdir('data/')
        # endswith: match real .csv files only (not e.g. 'x.csv.bak').
        if name.endswith('.csv')
        # Fix: exclude a previously generated combined file, otherwise a
        # re-run folds the output back into itself and duplicates rows.
        and name != output_name
    ]

    dataframes = [pd.read_csv(each_file) for each_file in csv_list]
    merged_df = pd.concat(dataframes, ignore_index=True)
    merged_df['merge'] = (merged_df['video_title'] + ' ' + merged_df['conference_name']
                          + ' ' + merged_df['summary'] + ' ' + merged_df['keywords'])
    merged_df.to_csv('data/' + output_name, index=False)
|
||
if __name__ == "__main__":
    # Regenerate the combined corpus CSV when executed directly.
    merge_csv()
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import nltk | ||
from nltk.tokenize import word_tokenize | ||
from rank_bm25 import BM25Okapi | ||
import pandas as pd | ||
from sentence_transformers import SentenceTransformer, util | ||
import pickle | ||
# Download the 'punkt' tokenizer models required by word_tokenize.
# NOTE(review): runs at import time and may hit the network on first use —
# consider nltk.download('punkt', quiet=True) to suppress repeated output.
nltk.download('punkt')
|
||
class BM25():
    """BM25 keyword retrieval over the pre-merged text column of *dataset*.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain 'merge' (searchable text) and 'video_id' columns.
    top_k : int
        Number of results to return per query.
    """

    def __init__(self, dataset, top_k=5):
        self.dataset = dataset
        self.top_k = top_k
        self.tokenized_corpus = [self.preprocess_text(doc) for doc in dataset['merge']]
        # Fix: build the BM25 index once here — the original rebuilt
        # BM25Okapi from the whole corpus on every run() call.
        self.bm25 = BM25Okapi(self.tokenized_corpus)

    def preprocess_text(self, text):
        """Lower-case and word-tokenize *text* for BM25 scoring."""
        return word_tokenize(text.lower())

    def search(self, query, bm25=None):
        """Score *query* against the corpus and return the top-k video ids.

        *bm25* optionally overrides the cached index; the parameter is kept
        (now with a default) for backward compatibility with callers that
        passed an index explicitly.
        """
        if bm25 is None:
            bm25 = self.bm25
        tokenized_query = self.preprocess_text(query)
        scores = bm25.get_scores(tokenized_query)
        top_n_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:self.top_k]

        results = []
        video_ids = []
        for i in top_n_indices:
            results.append((self.dataset['merge'][i], scores[i]))
            video_ids.append(self.dataset.loc[i]['video_id'])
        print(results)
        return video_ids

    def run(self, query):
        """Convenience wrapper: search the cached index for *query*."""
        return self.search(query)
|
||
class BIENCODER():
    """Dense-retrieval semantic search using a SentenceTransformer bi-encoder.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain 'merge' and 'video_id' columns; row i must align with
        row i of *embeddings*.
    embeddings : array-like
        Precomputed corpus embeddings (see the embedding generator script).
    top_k : int
        Number of hits to return per query.
    """

    def __init__(self, dataset, embeddings, top_k=5):
        self.dataset = dataset
        self.embeddings = embeddings
        self.top_k = top_k
        self.bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
        # Must match the truncation used when the corpus embeddings were built.
        self.bi_encoder.max_seq_length = 256

    def search(self, query):
        """Embed *query*, rank the corpus by similarity, return top-k video ids."""
        print("Input question:", query)
        question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
        # question_embedding = question_embedding.cuda()
        hits = util.semantic_search(question_embedding, self.embeddings, top_k=self.top_k)
        hits = hits[0]  # Get the hits for the first (only) query

        # Output of top-k hits from bi-encoder
        print("\n-------------------------\n")
        # Fix: report the actual number requested — the label was
        # hard-coded to "Top-3" while the default top_k is 5.
        print("Top-{} Bi-Encoder Retrieval hits".format(self.top_k))
        hits = sorted(hits, key=lambda x: x['score'], reverse=True)
        video_ids = []
        for hit in hits:
            print("\t{:.3f}\t{}".format(hit['score'], self.dataset['merge'][hit['corpus_id']]))
            video_ids.append(self.dataset.loc[hit['corpus_id']]['video_id'])
        return video_ids
|
||
if __name__ == "__main__":
    query = 'CNCF Webinars'  ## input query
    dataset = pd.read_csv('data/cncf_video_summary_combine.csv')

    print('Method 1: BM25 alg for semantic search:')
    bm25_search = BM25(dataset, top_k=5)
    video_ids = bm25_search.run(query)
    # Fix: removed leftover debug statement print('here').
    print(video_ids)

    print('Method 2: Deep learning for semantic search:')
    # NOTE: pickle.load is safe here only because embedding.pkl is produced
    # locally by the embedding generator; never unpickle untrusted data.
    with open('data/embedding.pkl', 'rb') as f:
        embeddings = pickle.load(f)
    video_ids = BIENCODER(dataset, embeddings).search(query)
    print(video_ids)