From e3b07a951ca6ef6ff613df9f3591a53df0216758 Mon Sep 17 00:00:00 2001
From: vichayturen <1073931273@qq.com>
Date: Fri, 26 Jul 2024 00:41:19 +0800
Subject: [PATCH 01/11] feat(answer-synthesize): enhance output with empty result messages

Provide clearer feedback to the LLM when no paragraphs or subgraphs are related to the query by adding explicit empty-result messages to the synthesis response.
---
 .../operators/index_op/build_semantic_index.py | 12 +++++++++---
 .../operators/llm_op/answer_synthesize.py      | 18 ++++++++++++------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index c70c8663..da50a112 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -17,6 +17,7 @@

 import os
+import json
 from typing import Any, Dict

 from hugegraph_llm.config import resource_path, settings
@@ -32,9 +33,14 @@ def __init__(self, embedding: BaseEmbedding):
         self.embedding = embedding

     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        vids = [vertex["id"] for vertex in context["vertices"]]
-        if len(vids) > 0:
-            log.debug("Building vector index for %s vertices...", len(vids))
+        if len(context["vertices"]) > 0:
+            log.debug("Building vector index for %s vertices...", len(context["vertices"]))
+            vids = []
+            vids_embedding = []
+            for vertex in context["vertices"]:
+                vertex_text = f"{vertex['label']}\n{vertex['properties']}"
+                vids_embedding.append(self.embedding.get_text_embedding(vertex_text))
+                vids.append(vertex["id"])
             vids_embedding = [self.embedding.get_text_embedding(vid) for vid in vids]
             log.debug("Vector index built for %s vertices.", len(vids))
             if os.path.exists(self.index_file) and os.path.exists(self.content_file):
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 8bcf05e0..c67cba5e 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -84,13 +84,19 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
             return {"answer": response}

         vector_result = context.get("vector_result", [])
-        vector_result_context = ("The following are paragraphs related to the query:\n"
-                                 + "\n".join([f"{i + 1}. {res}"
-                                              for i, res in enumerate(vector_result)]))
+        if len(vector_result) == 0:
+            vector_result_context = "There are no paragraphs related to the query."
+        else:
+            vector_result_context = ("The following are paragraphs related to the query:\n"
+                                     + "\n".join([f"{i + 1}. {res}"
+                                                  for i, res in enumerate(vector_result)]))

         graph_result = context.get("graph_result", [])
-        graph_result_context = ("The following are subgraph related to the query:\n"
-                                + "\n".join([f"{i + 1}. {res}"
-                                             for i, res in enumerate(graph_result)]))
+        if len(graph_result) == 0:
+            graph_result_context = "There are no subgraphs related to the query."
+        else:
+            graph_result_context = ("The following are subgraph related to the query:\n"
+                                    + "\n".join([f"{i + 1}. 
{res}" + for i, res in enumerate(graph_result)])) verbose = context.get("verbose") or False From c09730abbb5c1c1bc50689c4d4202d34a5d159c0 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Tue, 30 Jul 2024 12:12:42 +0800 Subject: [PATCH 02/11] fix: Add knowledge of "island nodes" to subgraph query extraction using two-stage queries. --- .../operators/hugegraph_op/graph_rag_query.py | 29 ++++++++++++++----- .../operators/index_op/semantic_id_query.py | 2 +- .../operators/llm_op/answer_synthesize.py | 2 +- .../llm_op/property_graph_extract.py | 17 ++++++----- 4 files changed, 34 insertions(+), 16 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index b76637cd..0d6aed6e 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -24,6 +24,9 @@ class GraphRAGQuery: + VERTEX_GREMLIN_QUERY_TEMPL = ( + "g.V().hasId({keywords}).as('subj').toList()" + ) ID_RAG_GREMLIN_QUERY_TEMPL = ( "g.V().hasId({keywords}).as('subj')" ".repeat(" @@ -114,25 +117,29 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if not use_id_to_match: keywords_str = ",".join("'" + kw + "'" for kw in keywords) - rag_gremlin_query_template = self.PROP_RAG_GREMLIN_QUERY_TEMPL - rag_gremlin_query = rag_gremlin_query_template.format( + rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format( prop=self._prop_to_match, keywords=keywords_str, max_deep=self._max_deep, max_items=self._max_items, edge_labels=edge_labels_str, ) + result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] + knowledge: Set[str] = self._format_knowledge_from_query_result(query_result=result) else: - rag_gremlin_query_template = self.ID_RAG_GREMLIN_QUERY_TEMPL - rag_gremlin_query = rag_gremlin_query_template.format( + rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format( + keywords=entrance_vids, + ) + result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] + knowledge: Set[str] = self._format_knowledge_from_vertex(query_result=result) + rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format( keywords=entrance_vids, max_deep=self._max_deep, max_items=self._max_items, edge_labels=edge_labels_str, ) - - result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] - knowledge: Set[str] = self._format_knowledge_from_query_result(query_result=result) + result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] + knowledge.update(self._format_knowledge_from_query_result(query_result=result)) context["graph_result"] = list(knowledge) context["synthesize_context_head"] = ( @@ -149,6 +156,14 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: return context + def _format_knowledge_from_vertex(self, query_result: List[Any]): + knowledge = set() + for item in query_result: + props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items()) + node_str = f"{item['id']}{{{props_str}}}" + knowledge.add(node_str) + return knowledge + def _format_knowledge_from_query_result( self, query_result: List[Any], diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py index 8b4e9d40..69997393 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py +++ 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py @@ -39,5 +39,5 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: results = self.vector_index.search(query_vector, top_k=1) if results: graph_query_entrance.append(results[0]) - context["entrance_vids"] = graph_query_entrance + context["entrance_vids"] = list(set(graph_query_entrance)) return context diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 8bcf05e0..96aa3f28 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -88,7 +88,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: + "\n".join([f"{i + 1}. {res}" for i, res in enumerate(vector_result)])) graph_result = context.get("graph_result", []) - graph_result_context = ("The following are subgraph related to the query:\n" + graph_result_context = ("The following are knowledge from HugeGraph related to the query:\n" + "\n".join([f"{i + 1}. {res}" for i, res in enumerate(graph_result)])) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index be44c7a5..1b520e02 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -53,8 +53,8 @@ vertexLabel的id生成策略为:id:primaryKey1!primaryKey2 Example -Input example: -text +# Input +## Text 道路交通事故认定书 鱼公交认字[2013]第00478号 天气:小雨 @@ -79,10 +79,10 @@ 二0一四年一月二日 -graph schema +## Graph schema {"vertexLabels":[{"id":3,"name":"法条","id_strategy":"PRIMARY_KEY","primary_keys":["法典名","法条索引"],"nullable_keys":["法章名","法条内容"],"properties":["法典名","法章名","法条索引","法条内容"]},{"id":7,"name":"事故","id_strategy":"PRIMARY_KEY","primary_keys":["事故认定书编号","事故认定书单位"],"nullable_keys":[],"properties":["事故发生时间","事故认定书编号","事故认定书单位"]},{"id":11,"name":"发生地点","id_strategy":"PRIMARY_KEY","primary_keys":["城市","所属路段"],"nullable_keys":["走向","材质","路面情况","道路状况"],"properties":["城市","走向","材质","路面情况","道路状况","所属路段"]},{"id":12,"name":"当事人","id_strategy":"PRIMARY_KEY","primary_keys":["身份证号"],"nullable_keys":["姓名","性别","年龄","民族","驾照"],"properties":["身份证号","姓名","性别","年龄","民族","驾照"]},{"id":13,"name":"车辆","id_strategy":"PRIMARY_KEY","primary_keys":["车辆牌照"],"nullable_keys":["行驶证所属人","保险公司","保险情况","车辆类型"],"properties":["车辆牌照","行驶证所属人","保险公司","保险情况","车辆类型"]},{"id":14,"name":"行为","id_strategy":"PRIMARY_KEY","primary_keys":["行为名称"],"nullable_keys":[],"properties":["行为名称"]}],"edgeLabels":[{"id":7,"name":"事故相关法条","source_label":"事故","target_label":"法条","sort_keys":[],"nullable_keys":[],"properties":[]},{"id":8,"name":"事故相关当事人","source_label":"事故","target_label":"当事人","sort_keys":[],"nullable_keys":["责任认定"],"properties":["责任认定"]},{"id":9,"name":"事故相关行为","source_label":"事故","target_label":"行为","sort_keys":[],"nullable_keys":[],"properties":[]},{"id":10,"name":"当事人相关行为","source_label":"当事人","target_label":"行为","sort_keys":[],"nullable_keys":[],"properties":[]},{"id":11,"name":"当事人相关车辆","source_label":"当事人","target_label":"车辆","sort_keys":[],"nullable_keys":[],"properties":[]},{"id":12,"name":"事故发生地点","source_label":"事故","target_label":"发生地点","sort_keys":[],"nullable_keys":[],"properties":[]}]} -Output example: +# Output [{"label":"事故","type":"vertex","properties":{"事故发生时间":"2013-11-24 
18:09:00.000","事故认定书编号":"鱼公交认字[2013]第00478号","事故认定书单位":"道路交通事故认定书"}},{"label":"发生地点","type":"vertex","properties":{"城市":"山东省鱼台县","所属路段":"251省道清河菜市场路口","走向":"南北","材质":"沥青","路面情况":"平坦","道路状况":"视线一般"}},{"label":"当事人","type":"vertex","properties":{"身份证号":"370827197201032316","姓名":"张小虎","性别":"男","年龄":"1972-01-03","驾照":"C1E"}},{"label":"当事人","type":"vertex","properties":{"身份证号":"370827195203122316","姓名":"于海洋","性别":"男","年龄":"1952-03-12"}},{"label":"车辆","type":"vertex","properties":{"车辆牌照":"鲁H7Z886","行驶证所属人":"谢彪","保险公司":"中国人民产保险股份有限公司济宁市分公司","保险情况":"交通事故责任强制保险","车辆类型":"小型轿车"}},{"label":"行为","type":"vertex","properties":{"行为名称":"逃逸"}},{"label":"行为","type":"vertex","properties":{"行为名称":"酒后驾车"}},{"label":"行为","type":"vertex","properties":{"行为名称":"观察不够"}},{"label":"法条","type":"vertex","properties":{"法典名":"中华人民共和国道路交通安全法","法条索引":"第三十八条","法条内容":"车辆、行人应当按照交通信号通行;遇有交通警察现场指挥时,应当按照交通警察的指挥通行;在没有交通信号的道路上,应当在确保安全、畅通的原则下通行。"}},{"label":"法条","type":"vertex","properties":{"法典名":"中华人民共和国道路交通安全法","法条索引":"第二十二条","法条内容":"饮酒,服用国家管制的精神药品或者醉药品,或者患有妨碍安全驾驶杭动车的疾病,或者过度劳影响安全驾驶的,不得买驶机动车。"}},{"label":"事故相关法条","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"事故","inV":"3:中华人民共和国道路交通安全法!第三十八条","inVLabel":"法条","properties":{}},{"label":"事故相关法条","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"事故","inV":"3:中华人民共和国道路交通安全法!第二十二条","inVLabel":"法条","properties":{}},{"label":"事故相关当事人","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"事故","inV":"12: 370827197201032316","inVLabel":"当事人","properties":{"责任认定":"全部责任"}},{"label":"事故相关当事人","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"事故","inV":"12: 370827195203122316","inVLabel":"当事人","properties":{"责任认定":"无责任"}},{"label":"事故相关行为","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"当事人","inV":"14:逃逸","inVLabel":"行为","properties":{}},{"label":"事故相关行为","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"当事人","inV":"14:酒后驾车","inVLabel":"行为","properties":{}},{"label":"事故相关行为","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"当事人","inV":"14:观察不够","inVLabel":"行为","properties":{}},{"label":"当事人相关行为","type":"edge","outV":"12:370827197201032316","outVLabel":"当事人","inV":"14:逃逸","inVLabel":"行为","properties":{}},{"label":"当事人相关行为","type":"edge","outV":"12:370827197201032316","outVLabel":"当事人","inV":"14:酒后驾车","inVLabel":"行为","properties":{}},{"label":"当事人相关行为","type":"edge","outV":"12:370827197201032316","outVLabel":"当事人","inV":"14:观察不够","inVLabel":"行为","properties":{}},{"label":"当事人相关车辆","type":"edge","outV":"12:370827197201032316","outVLabel":"当事人","inV":"13:鲁H7Z886","inVLabel":"车辆","properties":{}},{"label":"事故发生地点","type":"edge","outV":"7:鱼公交认字[2013]第00478号!道路交通事故认定书","outVLabel":"事故","inV":"11:山东省鱼台县!251省道清河菜市场路口","inVLabel":"发生地点","properties":{}}] """ @@ -91,10 +91,13 @@ def generate_extract_property_graph_prompt(text, schema=None) -> str: return f"""--- 请根据上面的完整指令, 尝试根据下面给定的 schema, 提取下面的文本, 只需要输出 json 结果: -## Text: +# Input +## Text {text} -## Graph schema: -{schema}""" +## Graph schema +{schema} + +# Output""" def split_text(text: str) -> List[str]: From ea62e23926e1532989e072ab8960f221faa5ba1a Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Tue, 30 Jul 2024 12:47:55 +0800 Subject: [PATCH 03/11] fix: change prompt --- .../hugegraph_llm/operators/index_op/build_semantic_index.py | 1 - .../src/hugegraph_llm/operators/llm_op/answer_synthesize.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py index da50a112..a439d467 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py @@ -17,7 +17,6 @@ import os -import json from typing import Any, Dict from hugegraph_llm.config import resource_path, settings diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index c96483cb..acfbf3e8 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -94,7 +94,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if len(graph_result) == 0: graph_result_context = "There are no knowledge from HugeGraph related to the query." else: - graph_result_context = ("The following are subgraph related to the query:\n" + graph_result_context = ("The following are knowledge from HugeGraph related to the query:\n" + "\n".join([f"{i + 1}. {res}" for i, res in enumerate(graph_result)])) From 442ed722c2458454b65d5767807593c5403ab09a Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Mon, 5 Aug 2024 15:11:34 +0800 Subject: [PATCH 04/11] Add optional parameters for keyword matching vid --- .../src/hugegraph_llm/operators/graph_rag_task.py | 3 ++- .../hugegraph_llm/operators/index_op/semantic_id_query.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index e343fad4..e3dcd99e 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -58,10 +58,11 @@ def extract_keyword( ) return self - def match_keyword_to_id(self): + def match_keyword_to_id(self, topk_per_keyword: int = 1): self._operators.append( SemanticIdQuery( embedding=self._embedding, + topk_per_keyword=topk_per_keyword ) ) return self diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py index 69997393..4eaef4e4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py @@ -25,19 +25,20 @@ class SemanticIdQuery: - def __init__(self, embedding: BaseEmbedding): + def __init__(self, embedding: BaseEmbedding, topk_per_keyword: int = 1): index_file = str(os.path.join(resource_path, settings.graph_name, "vid.faiss")) content_file = str(os.path.join(resource_path, settings.graph_name, "vid.pkl")) self.vector_index = VectorIndex.from_index_file(index_file, content_file) self.embedding = embedding + self._topk_per_keyword = topk_per_keyword def run(self, context: Dict[str, Any]) -> Dict[str, Any]: keywords = context["keywords"] graph_query_entrance = [] for keyword in keywords: query_vector = self.embedding.get_text_embedding(keyword) - results = self.vector_index.search(query_vector, top_k=1) + results = self.vector_index.search(query_vector, top_k=self._topk_per_keyword) if results: - graph_query_entrance.append(results[0]) + graph_query_entrance.extend(results[:self._topk_per_keyword]) context["entrance_vids"] = list(set(graph_query_entrance)) 
return context From e1722f7b57fdb70e126aa854c89f63ea3c35b92d Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Tue, 6 Aug 2024 13:04:51 +0800 Subject: [PATCH 05/11] Add asynchronous methods to the four types of generation functions in the rag web demo --- .../hugegraph_llm/models/embeddings/base.py | 7 +++ .../hugegraph_llm/models/embeddings/ollama.py | 9 ++++ .../hugegraph_llm/models/embeddings/openai.py | 5 +++ .../models/embeddings/qianfan.py | 8 ++++ .../src/hugegraph_llm/models/llms/base.py | 8 ++++ .../src/hugegraph_llm/models/llms/ollama.py | 21 +++++++++ .../src/hugegraph_llm/models/llms/openai.py | 30 +++++++++++++ .../src/hugegraph_llm/models/llms/qianfan.py | 17 +++++++ .../operators/llm_op/answer_synthesize.py | 44 ++++++++++++------- 9 files changed, 134 insertions(+), 15 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py index 15bc4eae..2ea8786c 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py @@ -60,6 +60,13 @@ def get_text_embedding( ) -> List[float]: """Comment""" + @abstractmethod + async def async_get_text_embedding( + self, + text: str + ) -> List[float]: + """Comment""" + @staticmethod def similarity( embedding1: Union[List[float], np.ndarray], diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py index f87502c8..81e11cc5 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py @@ -32,6 +32,7 @@ def __init__( ): self.model = model self.client = ollama.Client(host=f"http://{host}:{port}", **kwargs) + self.async_client = ollama.AsyncClient(host=f"http://{host}:{port}", **kwargs) self.embedding_dimension = None def get_text_embedding( @@ -40,3 +41,11 @@ def get_text_embedding( ) -> List[float]: """Comment""" return list(self.client.embeddings(model=self.model, prompt=text)["embedding"]) + + async def async_get_text_embedding( + self, + text: str + ) -> List[float]: + """Comment""" + response = await self.async_client.embeddings(model=self.model, prompt=text) + return list(response["embedding"]) diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py index 267effac..2a092e75 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py @@ -38,3 +38,8 @@ def get_text_embedding(self, text: str) -> List[float]: """Comment""" response = self.client.create(input=text, model=self.embedding_model_name) return response.data[0].embedding + + async def async_get_text_embedding(self, text: str) -> List[float]: + """Comment""" + response = await self.client.acreate(input=text, model=self.embedding_model_name) + return response.data[0].embedding diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py index c86a9209..2f41fe5d 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py @@ -49,3 +49,11 @@ def get_text_embedding(self, text: str) -> List[float]: texts=[text] ) return response["body"]["data"][0]["embedding"] + + async def async_get_text_embedding(self, text: str) -> List[float]: + """ Usage refer: 
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn""" + response = await self.client.ado( + model=self.embedding_model_name, + texts=[text] + ) + return response["body"]["data"][0]["embedding"] diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py index 0e051566..04c1c27d 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py @@ -31,6 +31,14 @@ def generate( ) -> str: """Comment""" + @abstractmethod + async def agenerate( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + ) -> str: + """Comment""" + @abstractmethod def generate_streaming( self, diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py index f94e268b..59655991 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py @@ -29,6 +29,7 @@ class OllamaClient(BaseLLM): def __init__(self, model: str, host: str = "127.0.0.1", port: int = 11434, **kwargs): self.model = model self.client = ollama.Client(host=f"http://{host}:{port}", **kwargs) + self.async_client = ollama.AsyncClient(host=f"http://{host}:{port}", **kwargs) @retry(tries=3, delay=1) def generate( @@ -50,6 +51,26 @@ def generate( print(f"Retrying LLM call {e}") raise e + @retry(tries=3, delay=1) + async def agenerate( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + ) -> str: + """Comment""" + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." + messages = [{"role": "user", "content": prompt}] + try: + response = await self.async_client.chat( + model=self.model, + messages=messages, + ) + return response["message"]["content"] + except Exception as e: + print(f"Retrying LLM call {e}") + raise e + def generate_streaming( self, messages: Optional[List[Dict[str, Any]]] = None, diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py index 4f50e96a..36cb11d3 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py @@ -73,6 +73,36 @@ def generate( log.error("Retrying LLM call %s", e) raise e + @retry(tries=3, delay=1) + async def agenerate( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + ) -> str: + """Generate a response to the query messages/prompt.""" + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." 
+ messages = [{"role": "user", "content": prompt}] + try: + completions = await openai.ChatCompletion.acreate( + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + messages=messages, + ) + return completions.choices[0].message.content + # catch context length / do not retry + except openai.error.InvalidRequestError as e: + log.critical("Fatal: %s", e) + return str(f"Error: {e}") + # catch authorization errors / do not retry + except openai.error.AuthenticationError: + log.critical("The provided OpenAI API key is invalid") + return "Error: The provided OpenAI API key is invalid" + except Exception as e: + log.error("Retrying LLM call %s", e) + raise e + def generate_streaming( self, messages: Optional[List[Dict[str, Any]]] = None, diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py index bebfa1ae..25d5e212 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py @@ -49,6 +49,23 @@ def generate( ) return response.body["result"] + @retry(tries=3, delay=1) + async def agenerate( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + ) -> str: + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." + messages = [{"role": "user", "content": prompt}] + + response = await self.chat_comp.ado(model=self.chat_model, messages=messages) + if response.code != 200: + raise Exception( + f"Request failed with code {response.code}, message: {response.body['error_msg']}" + ) + return response.body["result"] + def generate_streaming( self, messages: Optional[List[Dict[str, Any]]] = None, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index acfbf3e8..27814255 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -16,6 +16,7 @@ # under the License. +import asyncio from typing import Any, Dict, Optional from hugegraph_llm.models.llms.base import BaseLLM @@ -97,15 +98,18 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: graph_result_context = ("The following are knowledge from HugeGraph related to the query:\n" + "\n".join([f"{i + 1}. 
{res}" for i, res in enumerate(graph_result)])) + context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, + vector_result_context, graph_result_context)) - verbose = context.get("verbose") or False + return context + async def async_generate(self, context: Dict[str, Any], context_head_str: str, context_tail_str: str, + vector_result_context: str, graph_result_context: str): + verbose = context.get("verbose") or False + task_cache = {} if self._raw_answer: prompt = self._question - response = self._llm.generate(prompt=prompt) - context["raw_answer"] = response - if verbose: - print(f"\033[91mANSWER: {response}\033[0m") + task_cache["raw_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) if self._vector_only_answer: context_str = (f"{context_head_str}\n" f"{vector_result_context}\n" @@ -115,10 +119,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context_str=context_str, query_str=self._question, ) - response = self._llm.generate(prompt=prompt) - context["vector_only_answer"] = response - if verbose: - print(f"\033[91mANSWER: {response}\033[0m") + task_cache["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) if self._graph_only_answer: context_str = (f"{context_head_str}\n" f"{graph_result_context}\n" @@ -128,10 +129,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context_str=context_str, query_str=self._question, ) - response = self._llm.generate(prompt=prompt) - context["graph_only_answer"] = response - if verbose: - print(f"\033[91mANSWER: {response}\033[0m") + task_cache["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" context_str = (f"{context_head_str}\n" @@ -142,9 +140,25 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context_str=context_str, query_str=self._question, ) - response = self._llm.generate(prompt=prompt) + task_cache["graph_vector_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) + if task_cache.get("raw_task"): + response = await task_cache["raw_task"] + context["raw_answer"] = response + if verbose: + print(f"\033[91mANSWER: {response}\033[0m") + if task_cache.get("vector_only_task"): + response = await task_cache["vector_only_task"] + context["vector_only_answer"] = response + if verbose: + print(f"\033[91mANSWER: {response}\033[0m") + if task_cache.get("graph_only_task"): + response = await task_cache["graph_only_task"] + context["graph_only_answer"] = response + if verbose: + print(f"\033[91mANSWER: {response}\033[0m") + if task_cache.get("graph_vector_task"): + response = await task_cache["graph_vector_task"] context["graph_vector_answer"] = response if verbose: print(f"\033[91mANSWER: {response}\033[0m") - return context From f0deace6a78313df7a0c88f6153544caeb2425e7 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Tue, 6 Aug 2024 13:28:03 +0800 Subject: [PATCH 06/11] fix code style --- .../operators/common_op/check_schema.py | 3 ++- .../operators/llm_op/answer_synthesize.py | 16 ++++++++++------ .../operators/llm_op/disambiguate_data.py | 4 +++- .../operators/llm_op/info_extract.py | 3 ++- .../operators/llm_op/property_graph_extract.py | 3 ++- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index 616c7b1c..7a1f64ad 100644 --- 
a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -31,7 +31,8 @@ def run(self, schema=None) -> Any: # pylint: disable=too-many-branches raise ValueError("Input data is not a dictionary.") if "vertexlabels" not in schema or "edgelabels" not in schema: raise ValueError("Input data does not contain 'vertexlabels' or 'edgelabels'.") - if not isinstance(schema["vertexlabels"], list) or not isinstance(schema["edgelabels"], list): + if not isinstance(schema["vertexlabels"], list) or not isinstance(schema["edgelabels"], + list): raise ValueError("'vertexlabels' or 'edgelabels' in input data is not a list.") for vertex in schema["vertexlabels"]: if not isinstance(vertex, dict): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 27814255..f885182b 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -95,16 +95,18 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if len(graph_result) == 0: graph_result_context = "There are no knowledge from HugeGraph related to the query." else: - graph_result_context = ("The following are knowledge from HugeGraph related to the query:\n" - + "\n".join([f"{i + 1}. {res}" - for i, res in enumerate(graph_result)])) + graph_result_context = ( + "The following are knowledge from HugeGraph related to the query:\n" + + "\n".join([f"{i + 1}. {res}" + for i, res in enumerate(graph_result)])) context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, vector_result_context, graph_result_context)) return context - async def async_generate(self, context: Dict[str, Any], context_head_str: str, context_tail_str: str, - vector_result_context: str, graph_result_context: str): + async def async_generate(self, context: Dict[str, Any], context_head_str: str, + context_tail_str: str, vector_result_context: str, + graph_result_context: str): verbose = context.get("verbose") or False task_cache = {} if self._raw_answer: @@ -140,7 +142,9 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, c context_str=context_str, query_str=self._question, ) - task_cache["graph_vector_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) + task_cache["graph_vector_task"] = asyncio.create_task( + self._llm.agenerate(prompt=prompt) + ) if task_cache.get("raw_task"): response = await task_cache["raw_task"] context["raw_answer"] = response diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index e34b6375..44bb69d1 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -52,5 +52,7 @@ def run(self, data: Dict) -> Dict[str, List[Any]]: llm_output = self.llm.generate(prompt=prompt) data["triples"] = [] extract_triples_by_regex(llm_output, data) - print(f"LLM {self.__class__.__name__} input:{prompt} \n output: {llm_output} \n data: {data}") + print( + f"LLM {self.__class__.__name__} input:{prompt} \n" + f" output: {llm_output} \n data: {data}") return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 
6f0e3afb..1424143b 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -152,7 +152,8 @@ def run(self, context: Dict[str, Any]) -> Dict[str, List[Any]]: for sentence in chunks: proceeded_chunk = self.extract_triples_by_llm(schema, sentence) - log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, sentence, proceeded_chunk) + log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, + sentence, proceeded_chunk) if schema: extract_triples_by_regex_with_schema(schema, proceeded_chunk, context) else: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index 459e4a54..171b48a6 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -101,7 +101,8 @@ def run(self, context: Dict[str, Any]) -> Dict[str, List[Any]]: items = [] for chunk in chunks: proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) - log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, chunk, proceeded_chunk) + log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, chunk, + proceeded_chunk) items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) items = self.filter_item(schema, items) for item in items: From 34184bd27c3560a9acef167ee7bb5489d729a2e6 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Fri, 9 Aug 2024 18:07:37 +0800 Subject: [PATCH 07/11] Add todo --- .../hugegraph_llm/operators/common_op/merge_dedup_rerank.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index a65a7098..af6ce6a6 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -35,6 +35,10 @@ def __init__(self, embedding: BaseEmbedding, topk: int = 10): self.topk = topk def run(self, context: Dict[str, Any]) -> Dict[str, Any]: + # TODO: 逻辑应该是: + # 1. 分词后优先关键词匹配 vid (直接匹配) + # 2. 匹配不到,尝试退化模糊vid召回(并提示用户是否想问的是xxx) + # 3. 
之后我们可以把chunk/graph加上,查询的时候可以同时查询关键词对应的chunk,然后一起综合(进阶) query = context.get("query") vector_result = context.get("vector_result", []) From a46112f07e6f6d7a1a11345b8f977baa27e7b42b Mon Sep 17 00:00:00 2001 From: imbajin Date: Sun, 11 Aug 2024 13:01:53 +0800 Subject: [PATCH 08/11] chore: mark some todos --- .../src/hugegraph_llm/demo/rag_web_demo.py | 8 +- .../src/hugegraph_llm/enums/id_strategy.py | 1 + .../src/hugegraph_llm/indices/graph_index.py | 4 +- .../operators/hugegraph_op/graph_rag_query.py | 90 +++++++++---------- 4 files changed, 55 insertions(+), 48 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index 08a0d2de..f82ea0a9 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -50,10 +50,12 @@ def convert_bool_str(string): raise gr.Error(f"Invalid boolean string: {string}") +# TODO: enhance/distinguish the "graph_rag" name to avoid confusion def graph_rag(text: str, raw_answer: str, vector_only_answer: str, graph_only_answer: str, graph_vector_answer): vector_search = convert_bool_str(vector_only_answer) or convert_bool_str(graph_vector_answer) graph_search = convert_bool_str(graph_only_answer) or convert_bool_str(graph_vector_answer) + if raw_answer == "false" and not vector_search and not graph_search: gr.Warning("Please select at least one generate mode.") return "", "", "", "" @@ -68,6 +70,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str, graph_only_answer=convert_bool_str(graph_only_answer), graph_vector_answer=convert_bool_str(graph_vector_answer) ).run(verbose=True, query=text) + try: context = searcher.run(verbose=True, query=text) return ( @@ -76,9 +79,12 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str, context.get("graph_only_answer", ""), context.get("graph_vector_answer", "") ) - except Exception as e: # pylint: disable=broad-exception-caught + except ValueError as e: log.error(e) raise gr.Error(str(e)) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(f"An unexpected error occurred: {str(e)}") def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-many-branches diff --git a/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py b/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py index 8db43743..5f3cadb0 100644 --- a/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py +++ b/hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py @@ -19,6 +19,7 @@ from enum import Enum +# Note: we don't support the "UUID" strategy for now class IdStrategy(Enum): AUTOMATIC = "AUTOMATIC" CUSTOMIZE_NUMBER = "CUSTOMIZE_NUMBER" diff --git a/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py b/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py index 39612044..74269fcb 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py @@ -36,12 +36,12 @@ def __init__( def clear_graph(self): self.client.gremlin().exec("g.V().drop()") + # TODO: replace triples with a more specific graph element type & implement it def add_triples(self, triples: list): - # TODO pass + # TODO: replace triples with a more specific graph element type & implement it def search_triples(self, max_deep: int = 2): - # TODO pass def execute_gremlin_query(self, query: str): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 0d6aed6e..2c81ee00 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -27,45 +27,47 @@ class GraphRAGQuery: VERTEX_GREMLIN_QUERY_TEMPL = ( "g.V().hasId({keywords}).as('subj').toList()" ) - ID_RAG_GREMLIN_QUERY_TEMPL = ( - "g.V().hasId({keywords}).as('subj')" - ".repeat(" - " bothE({edge_labels}).as('rel').otherV().as('obj')" - ").times({max_deep})" - ".path()" - ".by(project('label', 'id', 'props')" - " .by(label())" - " .by(id())" - " .by(valueMap().by(unfold()))" - ")" - ".by(project('label', 'inV', 'outV', 'props')" - " .by(label())" - " .by(inV().id())" - " .by(outV().id())" - " .by(valueMap().by(unfold()))" - ")" - ".limit({max_items})" - ".toList()" + # TODO: we could use a simpler query (like kneighbor-api to get the edges) + ID_RAG_GREMLIN_QUERY_TEMPL = """ + g.V().hasId({keywords}).as('subj') + .repeat( + bothE({edge_labels}).as('rel').otherV().as('obj') + ).times({max_deep}) + .path() + .by(project('label', 'id', 'props') + .by(label()) + .by(id()) + .by(valueMap().by(unfold())) ) - PROP_RAG_GREMLIN_QUERY_TEMPL = ( - "g.V().has('{prop}', within({keywords})).as('subj')" - ".repeat(" - " bothE({edge_labels}).as('rel').otherV().as('obj')" - ").times({max_deep})" - ".path()" - ".by(project('label', 'props')" - " .by(label())" - " .by(valueMap().by(unfold()))" - ")" - ".by(project('label', 'inV', 'outV', 'props')" - " .by(label())" - " .by(inV().values('{prop}'))" - " .by(outV().values('{prop}'))" - " .by(valueMap().by(unfold()))" - ")" - ".limit({max_items})" - ".toList()" + .by(project('label', 'inV', 'outV', 'props') + .by(label()) + .by(inV().id()) + .by(outV().id()) + .by(valueMap().by(unfold())) ) + .limit({max_items}) + .toList() + """ + + PROP_RAG_GREMLIN_QUERY_TEMPL = """ + g.V().has('{prop}', within({keywords})).as('subj') + .repeat( + bothE({edge_labels}).as('rel').otherV().as('obj') + ).times({max_deep}) + .path() + .by(project('label', 'props') + .by(label()) + .by(valueMap().by(unfold())) + ) + .by(project('label', 'inV', 'outV', 'props') + .by(label()) + .by(inV().values('{prop}')) + .by(outV().values('{prop}')) + .by(valueMap().by(unfold())) + ) + .limit({max_items}) + .toList() + """ def __init__( self, @@ -96,10 +98,10 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: user = context.get("user") or "admin" pwd = context.get("pwd") or "admin" self._client = PyHugeClient(ip=ip, port=port, graph=graph, user=user, pwd=pwd) - assert self._client is not None, "No graph for query." + assert self._client is not None, "No valid graph to search." keywords = context.get("keywords") - assert keywords is not None, "No keywords for query." + assert keywords is not None, "No keywords for graph query." entrance_vids = context.get("entrance_vids") assert entrance_vids is not None, "No entrance vertices for query." 
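The two-stage lookup that PATCH 02 introduced (and that this patch reformats) can be
sketched in isolation as a plain function. This is an illustrative sketch only:
exec_gremlin and the parameter names below are stand-ins for the real HugeGraph
client calls, not the project's API.

    from typing import Any, Callable, List, Set

    def two_stage_query(exec_gremlin: Callable[[str], List[Any]],
                        vertex_templ: str, path_templ: str,
                        keywords: str, **path_args: Any) -> Set[str]:
        knowledge: Set[str] = set()
        # Stage 1: fetch the matched vertices themselves, so "island"
        # vertices without qualifying edges still contribute knowledge.
        for v in exec_gremlin(vertex_templ.format(keywords=keywords)):
            props = ", ".join(f"{k}: {val}" for k, val in v["properties"].items())
            knowledge.add(f"{v['id']}{{{props}}}")
        # Stage 2: expand neighbourhood paths from the same entrance
        # vertices and merge both result sets (knowledge.update() in run()).
        for path in exec_gremlin(path_templ.format(keywords=keywords, **path_args)):
            knowledge.add(str(path))  # the real operator formats each path element
        return knowledge
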
@@ -149,14 +151,15 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
             "extracted based on key entities as subject:"
         )

+        # TODO: replace print with log
         verbose = context.get("verbose") or False
         if verbose:
-            print("\033[93mKNOWLEDGE FROM GRAPH:")
+            print("\033[93mKnowledge from Graph:")
             print("\n".join(rel for rel in context["graph_result"]) + "\033[0m")

         return context

-    def _format_knowledge_from_vertex(self, query_result: List[Any]):
+    def _format_knowledge_from_vertex(self, query_result: List[Any]) -> Set[str]:
         knowledge = set()
         for item in query_result:
             props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items())
@@ -164,10 +167,7 @@ def _format_knowledge_from_vertex(self, query_result: List[Any]):
             knowledge.add(node_str)
         return knowledge

-    def _format_knowledge_from_query_result(
-        self,
-        query_result: List[Any],
-    ) -> Set[str]:
+    def _format_knowledge_from_query_result(self, query_result: List[Any]) -> Set[str]:
         use_id_to_match = self._prop_to_match is None
         knowledge = set()
         for line in query_result:

From 831f5f09b0c02d8d8bbd47be6e162c9e608a9a2b Mon Sep 17 00:00:00 2001
From: imbajin
Date: Sun, 11 Aug 2024 23:34:42 +0800
Subject: [PATCH 09/11] refactor: change the option_name strings & add a format option

---
 .../src/hugegraph_llm/demo/rag_web_demo.py    | 84 ++++++++++---------
 .../hugegraph_llm/utils/hugegraph_utils.py    |  6 +-
 2 files changed, 46 insertions(+), 44 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index f82ea0a9..01bc85c4 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -99,10 +99,11 @@ def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-m
             text += para.text
             text += "\n"
     elif full_path.endswith(".pdf"):
-        raise gr.Error("PDF will be supported later!")
+        raise gr.Error("PDF will be supported later! Try to upload a txt/docx file for now")
     else:
         raise gr.Error("Please input txt or docx file.")
     builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
+
     if build_mode != "Rebuild vertex index":
         if schema:
             try:
@@ -114,19 +115,22 @@ def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-m
         else:
             return "ERROR: please input schema."
     builder.chunk_split(text, "paragraph", "zh")
-    if build_mode == "Rebuild vertex index":
+
+    # TODO: avoid hardcoding the "build_mode" strings (use var/constant instead)
+    if build_mode == "Rebuild Vector":
         builder.fetch_graph_data()
     else:
         builder.extract_info(example_prompt, "property_graph")
-    if build_mode != "Test":
-        if build_mode in ("Clear and import", "Rebuild vertex index"):
+    # "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
+    if build_mode != "Test Mode":
+        if build_mode in ("Clear and Import", "Rebuild Vector"):
             clean_vector_index()
         builder.build_vector_index()
-    if build_mode == "Clear and import":
+    if build_mode == "Clear and Import":
         clean_hg_data()
-    if build_mode in ("Clear and import", "Import"):
+    if build_mode in ("Clear and Import", "Import Mode"):
         builder.commit_to_hugegraph()
-    if build_mode != "Test":
+    if build_mode != "Test Mode":
         builder.build_vertex_id_semantic_index()
     log.debug(builder.operators)
     try:
@@ -325,22 +329,18 @@ def apply_embedding_configuration(arg1, arg2, arg3):

     gr.Markdown(
-        """## 1. build knowledge graph
+        """## 1. Build vector/graph RAG (💡)
 - Document: Input document file which should be TXT or DOCX.
 - Schema: Accepts two types of text as below:
     - User-defined JSON format Schema.
-    - Specify the name of the HugeGraph graph instance, and it will
-      automatically extract the schema of the graph.
+    - Specify the name of the HugeGraph graph instance, and it will automatically get the schema from it.
 - Info extract head: The head of prompt of info extracting.
 - Build mode:
-    - Test: Only extract vertices and edges from file without building vector index or
-      importing into HugeGraph.
-    - Clear and Import: Clear the vector index and data of HugeGraph and then extract and
-      import new data.
-    - Import: Extract the data and append it to HugeGraph and vector index without clearing
-      anything.
-    - Rebuild vertex index: Do not clear the HugeGraph data, but only clear vector index
-      and build new one.
+    - Test Mode: Only extract vertices and edges from the file into memory (without building the vector index or
+      writing data into HugeGraph)
+    - Import Mode: Extract the data and append it to HugeGraph & the vector index (without clearing any existing data)
+    - Clear and Import: Clear all existing RAG data (vector + graph), then rebuild it from the current input
+    - Rebuild Vector: Only rebuild the vector index (keep the graph data intact)
 """
 )
@@ -386,10 +386,9 @@
             info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT,
                                                label="Info extract head")
         with gr.Column():
-            mode = gr.Radio(choices=["Test", "Clear and import", "Import",
-                                     "Rebuild vertex index"],
-                            value="Test", label="Build mode")
-            btn = gr.Button("Build knowledge graph")
+            mode = gr.Radio(choices=["Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"],
+                            value="Test Mode", label="Build mode")
+            btn = gr.Button("Build Vector/Graph RAG")
         with gr.Row():
             out = gr.Textbox(label="Output", show_copy_button=True)
         btn.click(  # pylint: disable=no-member
            fn=build_kg,
            inputs=[file, schema, info_extract_template, mode],
            outputs=out
        )

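The TODO above ("avoid hardcoding the build_mode strings") points at centralizing the
mode names. One possible shape (an assumed sketch, not code from this series) is a
small enum shared by the Radio choices and build_kg():

    from enum import Enum

    class BuildMode(str, Enum):
        TEST = "Test Mode"
        IMPORT = "Import Mode"
        CLEAR_AND_IMPORT = "Clear and Import"
        REBUILD_VECTOR = "Rebuild Vector"

    # e.g. gr.Radio(choices=[m.value for m in BuildMode], value=BuildMode.TEST.value)
    # and in build_kg(): if build_mode != BuildMode.TEST.value: ...

-    gr.Markdown("""## 2. Retrieval augmented generation by hugegraph""")
+    gr.Markdown("""## 2. 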
RAG with HugeGraph 📖""") with gr.Row(): with gr.Column(scale=2): - inp = gr.Textbox(value="Tell me about Sarah.", label="Question") - raw_out = gr.Textbox(label="Raw LLM Answer", show_copy_button=True) - vector_only_out = gr.Textbox(label="Vector-only answer", show_copy_button=True) - graph_only_out = gr.Textbox(label="Graph-only answer", show_copy_button=True) - graph_vector_out = gr.Textbox(label="Graph-Vector answer", show_copy_button=True) + inp = gr.Textbox(value="Tell me about Sarah.", label="Question", show_copy_button=True) + raw_out = gr.Textbox(label="Basic LLM Answer", show_copy_button=True) + vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True) + graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True) + graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True) with gr.Column(scale=1): raw_radio = gr.Radio(choices=["true", "false"], value="false", - label="Raw LLM answer") + label="Basic LLM Answer") vector_only_radio = gr.Radio(choices=["true", "false"], value="true", - label="Vector-only answer") + label="Vector-only Answer") graph_only_radio = gr.Radio(choices=["true", "false"], value="false", - label="Graph-only answer") + label="Graph-only Answer") graph_vector_radio = gr.Radio(choices=["true", "false"], value="false", - label="Graph-Vector answer") - btn = gr.Button("Retrieval augmented generation") - btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member + label="Graph-Vector Answer") + btn = gr.Button("Answer Question") + btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member graph_vector_radio], outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out]) - gr.Markdown("""## 3. Others """) + gr.Markdown("""## 3. Others (🚧) """) with gr.Row(): - inp = [] + with gr.Column(): + inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True) + format = gr.Checkbox(label="Format JSON", value=True) out = gr.Textbox(label="Output", show_copy_button=True) - btn = gr.Button("Initialize HugeGraph test data") - btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member + btn = gr.Button("Run gremlin query on HugeGraph") + btn.click(fn=run_gremlin_query, inputs=[inp, format], outputs=out) # pylint: disable=no-member with gr.Row(): - inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query") + inp = [] out = gr.Textbox(label="Output", show_copy_button=True) - btn = gr.Button("Run gremlin query on HugeGraph") - btn.click(fn=run_gremlin_query, inputs=inp, outputs=out) # pylint: disable=no-member + btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)") + btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member + app = gr.mount_gradio_app(app, hugegraph_llm, path="/") # Note: set reload to False in production environment uvicorn.run(app, host=args.host, port=args.port) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py index e1d51cbc..d942bf93 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +import json from pyhugegraph.client import PyHugeClient from hugegraph_llm.config import settings -def run_gremlin_query(query): +def run_gremlin_query(query, format=False): res = get_hg_client().gremlin().exec(query) - return res + return json.dumps(res, indent=4, ensure_ascii=False) if format else res def get_hg_client(): From d7c8a212cd1095e0659565dc5e052e72e79ee848 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Mon, 12 Aug 2024 01:27:42 +0800 Subject: [PATCH 10/11] chore: change todos in rerank method --- .../hugegraph_llm/operators/common_op/merge_dedup_rerank.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index af6ce6a6..19ad4e47 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -35,10 +35,7 @@ def __init__(self, embedding: BaseEmbedding, topk: int = 10): self.topk = topk def run(self, context: Dict[str, Any]) -> Dict[str, Any]: - # TODO: 逻辑应该是: - # 1. 分词后优先关键词匹配 vid (直接匹配) - # 2. 匹配不到,尝试退化模糊vid召回(并提示用户是否想问的是xxx) - # 3. 之后我们可以把chunk/graph加上,查询的时候可以同时查询关键词对应的chunk,然后一起综合(进阶) + # TODO: exact > fuzzy; vertex > 1-depth-neighbour > 2-depth-neighbour; priority vertices query = context.get("query") vector_result = context.get("vector_result", []) From 1a80c7822904b7c815a6313862b927cc08b7466b Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Mon, 12 Aug 2024 14:19:35 +0800 Subject: [PATCH 11/11] fix code style --- style/pylint.conf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style/pylint.conf b/style/pylint.conf index 2ae32816..80ebc616 100644 --- a/style/pylint.conf +++ b/style/pylint.conf @@ -337,7 +337,7 @@ indent-after-paren=4 indent-string=' ' # Maximum number of characters on a single line. -max-line-length=100 +max-line-length=120 # Maximum number of lines in a module. max-module-lines=1000
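
A note on the formatting option PATCH 09 adds to run_gremlin_query(): it hinges on a
single json.dumps call, and ensure_ascii=False matters because the graph data in this
series is Chinese. A minimal standalone illustration (the result dict below is an
assumed example, not real query output):

    import json

    res = {"data": [{"id": "12:370827197201032316", "label": "当事人"}]}
    print(json.dumps(res, indent=4, ensure_ascii=False))  # keeps 当事人 readable
    print(json.dumps(res, indent=4))  # escapes it to \u5f53\u4e8b\u4eba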