diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index cd4941b7..18e6ef12 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -67,6 +67,9 @@ def get_all_agent_cot_samples(cls, profile_name): def add_sample(cls, profile_name, question, answer): logger.info(f'add sample question: {question} to profile {profile_name}') embedding = cls.create_vector_embedding_with_bedrock(question) + has_same_sample = cls.search_same_query(profile_name, 1, 'uba', embedding) + if has_same_sample: + logger.info(f'delete sample sample entity: {question} to profile {profile_name}') if cls.opensearch_dao.add_sample('uba', profile_name, question, answer, embedding): logger.info('Sample added') @@ -74,13 +77,19 @@ def add_sample(cls, profile_name, question, answer): def add_entity_sample(cls, profile_name, entity, comment): logger.info(f'add sample entity: {entity} to profile {profile_name}') embedding = cls.create_vector_embedding_with_bedrock(entity) + has_same_sample = cls.search_same_query(profile_name, 1, 'uba_ner', embedding) + if has_same_sample: + logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') if cls.opensearch_dao.add_entity_sample('uba_ner', profile_name, entity, comment, embedding): logger.info('Sample added') @classmethod def add_agent_cot_sample(cls, profile_name, entity, comment): - logger.info(f'add sample entity: {entity} to profile {profile_name}') + logger.info(f'add agent sample query: {entity} to profile {profile_name}') embedding = cls.create_vector_embedding_with_bedrock(entity) + has_same_sample = cls.search_same_query(profile_name, 1, 'uba_agent', embedding) + if has_same_sample: + logger.info(f'delete agent sample sample query: {entity} to profile {profile_name}') if cls.opensearch_dao.add_agent_cot_sample('uba_agent', profile_name, entity, comment, embedding): logger.info('Sample added') @@ -124,3 +133,29 @@ def search_sample(cls, profile_name, top_k, index_name, query): logger.info(f'search sample question: {query} {index_name} from profile {profile_name}') sample_list = cls.opensearch_dao.search_sample(profile_name, top_k, index_name, query) return sample_list + + @classmethod + def search_sample_with_embedding(cls, profile_name, top_k, index_name, query_embedding): + sample_list = cls.opensearch_dao.search_sample_with_embedding(profile_name, top_k, index_name, query_embedding) + return sample_list + + @classmethod + def search_same_query(cls, profile_name, top_k, index_name, embedding): + search_res = cls.search_sample_with_embedding(profile_name, top_k, index_name, embedding) + if len(search_res) > 0: + similarity_sample = search_res[0] + similarity_score = similarity_sample["_score"] + similarity_id = similarity_sample['_id'] + if similarity_score == 1.0: + if index_name == "uba": + cls.delete_sample(profile_name, similarity_id) + return True + elif index_name == "uba_ner": + cls.delete_entity_sample(profile_name, similarity_id) + return True + elif index_name == "uba_agent": + cls.delete_agent_cot_sample(profile_name, similarity_id) + return True + else: + return False + return False diff --git a/application/nlq/data_access/opensearch.py b/application/nlq/data_access/opensearch.py index 9ccfcffd..69cbe080 100644 --- a/application/nlq/data_access/opensearch.py +++ b/application/nlq/data_access/opensearch.py @@ -191,6 +191,10 @@ def delete_sample(self, index_name, profile_name, doc_id): def search_sample(self, profile_name, top_k, index_name, query): records_with_embedding = create_vector_embedding_with_bedrock(query, index_name=index_name) + return self.search_sample_with_embedding(profile_name, top_k, index_name, records_with_embedding['vector_field']) + + + def search_sample_with_embedding(self, profile_name, top_k, index_name, query_embedding): search_query = { "size": top_k, # Adjust the size as needed to retrieve more or fewer results "query": { @@ -205,7 +209,7 @@ def search_sample(self, profile_name, top_k, index_name, query): "knn": { "vector_field": { # Make sure 'vector_field' is the name of your vector field in OpenSearch - "vector": records_with_embedding['vector_field'], + "vector": query_embedding, "k": top_k # Adjust k as needed to retrieve more or fewer nearest neighbors } }