diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 878a2f30..c63ee51e 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -29,21 +29,23 @@ class RAGRequest(BaseModel):
     graph_ratio: float = 0.5
     rerank_method: Literal["bleu", "reranker"] = "bleu"
     near_neighbor_first: bool = False
-    with_gremlin_template: bool = True
+    with_gremlin_tmpl: bool = True
     custom_priority_info: str = ""
     answer_prompt: Optional[str] = None
-    num_gremlin_generate_example: int = 1
     keywords_extract_prompt: Optional[str] = None
+    gremlin_tmpl_num: int = 1
+    gremlin_prompt: Optional[str] = ""


 class GraphRAGRequest(BaseModel):
     query: str = ""
     gremlin_tmpl_num: int = 1
     with_gremlin_tmpl: bool = True
-    answer_prompt: Optional[str] = "" # FIXME: read from prompt_settings?
+    answer_prompt: Optional[str] = ""  # FIXME: read from prompt_settings?
    rerank_method: Literal["bleu", "reranker"] = "bleu"
     near_neighbor_first: bool = False
     custom_priority_info: str = ""
+    gremlin_prompt: Optional[str] = ""


 class GraphConfigRequest(BaseModel):
@@ -80,4 +82,4 @@ class RerankerConfigRequest(BaseModel):

 class LogStreamRequest(BaseModel):
     admin_token: Optional[str] = None
-    log_file: Optional[str] = 'llm-server.log'
+    log_file: Optional[str] = "llm-server.log"
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index a324728a..1c394a2d 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -15,9 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
-from typing import Literal
+from typing import Literal, Optional

-from fastapi import status, APIRouter, HTTPException
+from fastapi import status, APIRouter, HTTPException, Query

 from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
 from hugegraph_llm.api.models.rag_requests import (
@@ -36,17 +36,20 @@ def graph_rag_recall(
     query: str,
     gremlin_tmpl_num: int,
     with_gremlin_tmpl: bool,
-    answer_prompt: str,  # FIXME: text2gremlin should use it
+    answer_prompt: str,
     rerank_method: Literal["bleu", "reranker"],
     near_neighbor_first: bool,
     custom_related_information: str,
+    gremlin_prompt: str,
 ) -> dict:
     from hugegraph_llm.operators.graph_rag_task import RAGPipeline

     rag = RAGPipeline()
     rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
-        with_gremlin_template=with_gremlin_tmpl, num_gremlin_generate_example=gremlin_tmpl_num
+        with_gremlin_template=with_gremlin_tmpl,
+        num_gremlin_generate_example=gremlin_tmpl_num,
+        gremlin_prompt=gremlin_prompt,
     ).merge_dedup_rerank(
         rerank_method=rerank_method,
         near_neighbor_first=near_neighbor_first,
@@ -60,20 +63,44 @@ def rag_http_api(
     router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
-    def rag_answer_api(req: RAGRequest):
+    def rag_answer_api(
+        req: RAGRequest,
+        query_param: str = Query("", description="Query you want to ask"),
+        raw_answer_param: bool = Query(False, description="Use LLM to generate answer directly"),
+        vector_only_param: bool = Query(False, description="Use LLM to generate answer with vector"),
+        graph_only_param: bool = Query(False, description="Use LLM to generate answer with graph RAG only"),
+        graph_vector_answer_param: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG"),
+        with_gremlin_tmpl_param: bool = Query(True, description="Use example template in text2gremlin"),
+        graph_ratio_param: float = Query(0.5, description="The ratio of GraphRAG answer & vector answer"),
+        rerank_method_param: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results."),
+        near_neighbor_first_param: bool = Query(False, description="Prioritize near neighbors in the search results."),
+        custom_priority_info_param: str = Query("", description="Custom information to prioritize certain results."),
+        answer_prompt_param: Optional[str] = Query(
+            prompt.answer_prompt, description="Prompt to guide the answer generation."
+        ),
+        keywords_extract_prompt_param: Optional[str] = Query(
+            prompt.keywords_extract_prompt, description="Prompt for extracting keywords from the query."
+        ),
+        gremlin_tmpl_num_param: int = Query(1, description="Number of Gremlin templates to use."),
+        gremlin_prompt_param: Optional[str] = Query(
+            prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query. Don't change it casually."
+        ),
+    ):
         result = rag_answer_func(
             req.query,
             req.raw_answer,
             req.vector_only,
             req.graph_only,
             req.graph_vector_answer,
-            req.with_gremlin_template,
+            req.with_gremlin_tmpl,
             req.graph_ratio,
             req.rerank_method,
             req.near_neighbor_first,
             req.custom_priority_info,
             req.answer_prompt or prompt.answer_prompt,
             req.keywords_extract_prompt or prompt.keywords_extract_prompt,
+            req.gremlin_tmpl_num,
+            req.gremlin_prompt or prompt.gremlin_generate_prompt,
         )
         return {
             key: value
@@ -82,16 +109,31 @@ def rag_answer_api(req: RAGRequest):
         }

     @router.post("/rag/graph", status_code=status.HTTP_200_OK)
-    def graph_rag_recall_api(req: GraphRAGRequest):
+    def graph_rag_recall_api(
+        req: GraphRAGRequest,
+        query_param: str = Query("", description="Query you want to ask"),
+        with_gremlin_tmpl_param: bool = Query(True, description="Use example template in text2gremlin"),
+        rerank_method_param: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results."),
+        near_neighbor_first_param: bool = Query(False, description="Prioritize near neighbors in the search results."),
+        custom_priority_info_param: str = Query("", description="Custom information to prioritize certain results."),
+        answer_prompt_param: Optional[str] = Query(
+            prompt.answer_prompt, description="Prompt to guide the answer generation."
+        ),
+        gremlin_tmpl_num_param: int = Query(1, description="Number of Gremlin templates to use."),
+        gremlin_prompt_param: Optional[str] = Query(
+            prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query. Don't change it casually."
+        ),
+    ):
         try:
             result = graph_rag_recall(
                 query=req.query,
                 gremlin_tmpl_num=req.gremlin_tmpl_num,
                 with_gremlin_tmpl=req.with_gremlin_tmpl,
-                answer_prompt=req.answer_prompt,
+                answer_prompt=req.answer_prompt or prompt.answer_prompt,
                 rerank_method=req.rerank_method,
                 near_neighbor_first=req.near_neighbor_first,
                 custom_related_information=req.custom_priority_info,
+                gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
             )

             if isinstance(result, dict):
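# ---------------------------------------------------------------------------
# Reviewer note (illustration only, not part of the patch): a minimal sketch
# of calling the extended /rag endpoint. The host/port and the `requests`
# dependency are assumptions; the JSON keys mirror the updated RAGRequest.

import requests

resp = requests.post(
    "http://127.0.0.1:8001/rag",  # assumed llm-server address
    json={
        "query": "Tell me about Al Pacino.",
        "graph_only": True,
        "with_gremlin_tmpl": True,  # renamed from with_gremlin_template
        "gremlin_tmpl_num": 1,      # renamed from num_gremlin_generate_example
        "gremlin_prompt": "",       # falsy -> falls back to prompt.gremlin_generate_prompt
    },
)
print(resp.json())
# ---------------------------------------------------------------------------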
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 3edaaf52..070f3b3e 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -30,18 +30,20 @@


 def rag_answer(
-        text: str,
-        raw_answer: bool,
-        vector_only_answer: bool,
-        graph_only_answer: bool,
-        graph_vector_answer: bool,
-        with_gremlin_template: bool,
-        graph_ratio: float,
-        rerank_method: Literal["bleu", "reranker"],
-        near_neighbor_first: bool,
-        custom_related_information: str,
-        answer_prompt: str,
-        keywords_extract_prompt: str,
+    text: str,
+    raw_answer: bool,
+    vector_only_answer: bool,
+    graph_only_answer: bool,
+    graph_vector_answer: bool,
+    with_gremlin_template: bool,
+    graph_ratio: float,
+    rerank_method: Literal["bleu", "reranker"],
+    near_neighbor_first: bool,
+    custom_related_information: str,
+    answer_prompt: str,
+    keywords_extract_prompt: str,
+    gremlin_tmpl_num: Optional[int] = 2,
+    gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
 ) -> Tuple:
     """
     Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
@@ -52,15 +54,17 @@ def rag_answer(
     5. Run the pipeline and return the results.
     """
     should_update_prompt = (
-        prompt.default_question != text or
-        prompt.answer_prompt != answer_prompt or
-        prompt.keywords_extract_prompt != keywords_extract_prompt
+        prompt.default_question != text
+        or prompt.answer_prompt != answer_prompt
+        or prompt.keywords_extract_prompt != keywords_extract_prompt
+        or prompt.gremlin_generate_prompt != gremlin_prompt
     )
     if should_update_prompt or prompt.custom_rerank_info != custom_related_information:
         prompt.custom_rerank_info = custom_related_information
         prompt.default_question = text
         prompt.answer_prompt = answer_prompt
         prompt.keywords_extract_prompt = keywords_extract_prompt
+        prompt.gremlin_generate_prompt = gremlin_prompt
         prompt.update_yaml_file()

     vector_search = vector_only_answer or graph_vector_answer
@@ -74,9 +78,18 @@ def rag_answer(
         rag.query_vector_index()
     if graph_search:
         rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
-            huge_settings.graph_name).query_graphdb(with_gremlin_template=with_gremlin_template)
+            huge_settings.graph_name
+        ).query_graphdb(
+            with_gremlin_template=with_gremlin_template,
+            num_gremlin_generate_example=gremlin_tmpl_num,
+            gremlin_prompt=gremlin_prompt,
+        )
     # TODO: add more user-defined search strategies
-    rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, )
+    rag.merge_dedup_rerank(
+        graph_ratio,
+        rerank_method,
+        near_neighbor_first,
+    )
     rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)

     try:
@@ -123,6 +136,7 @@ def create_rag_block():
             graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
         with gr.Row():
             with_gremlin_template_radio = gr.Radio(choices=[True, False], value=True, label="With Gremlin Template")
+
     def toggle_slider(enable):
         return gr.update(interactive=enable)

@@ -164,16 +178,18 @@ def toggle_slider(enable):
                 near_neighbor_first,
                 custom_related_information,
                 answer_prompt_input,
-                keywords_extract_prompt_input
+                keywords_extract_prompt_input,
             ],
             outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
         )

-    gr.Markdown("""## 2. (Batch) Back-testing )
+    gr.Markdown(
+        """## 2. (Batch) Back-testing
 > 1. Download the template file & fill in the questions you want to test.
 > 2. Upload the file & click the button to generate answers. (Preview shows the first 40 lines)
 > 3. The answer options are the same as the above RAG/Q&A frame
-    """)
+    """
+    )
     tests_df_headers = [
         "Question",
         "Expected Answer",
@@ -214,18 +230,19 @@ def change_showing_excel(line_count):
         return df

     def several_rag_answer(
-            is_raw_answer: bool,
-            is_vector_only_answer: bool,
-            is_graph_only_answer: bool,
-            is_graph_vector_answer: bool,
-            graph_ratio: float,
-            rerank_method: Literal["bleu", "reranker"],
-            near_neighbor_first: bool,
-            custom_related_information: str,
-            answer_prompt: str,
-            keywords_extract_prompt: str,
-            progress=gr.Progress(track_tqdm=True),
-            answer_max_line_count: int = 1,
+        is_raw_answer: bool,
+        is_vector_only_answer: bool,
+        is_graph_only_answer: bool,
+        is_graph_vector_answer: bool,
+        graph_ratio: float,
+        rerank_method: Literal["bleu", "reranker"],
+        near_neighbor_first: bool,
+        with_gremlin_template: bool,
+        custom_related_information: str,
+        answer_prompt: str,
+        keywords_extract_prompt: str,
+        answer_max_line_count: int = 1,
+        progress=gr.Progress(track_tqdm=True),
     ):
         df = pd.read_excel(questions_path, dtype=str)
         total_rows = len(df)
@@ -240,6 +257,7 @@ def several_rag_answer(
+                with_gremlin_template,
                 graph_ratio,
                 rerank_method,
                 near_neighbor_first,
                 custom_related_information,
                 answer_prompt,
                 keywords_extract_prompt,
@@ -273,6 +291,7 @@ def several_rag_answer(
             graph_ratio,
             rerank_method,
             near_neighbor_first,
+            with_gremlin_template_radio,
             custom_related_information,
             answer_prompt_input,
             keywords_extract_prompt_input,
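# ---------------------------------------------------------------------------
# Reviewer note (illustration only, not part of the patch): rag_answer is
# called positionally, so with_gremlin_template must stay in the sixth slot,
# right before graph_ratio; the batch call in the hunk above is ordered
# accordingly. The sample values below are made up.

from hugegraph_llm.config import prompt
from hugegraph_llm.demo.rag_demo.rag_block import rag_answer

raw, vector_only, graph_only, graph_vector = rag_answer(
    "Tell me about Al Pacino.",      # text
    False,                           # raw_answer
    False,                           # vector_only_answer
    True,                            # graph_only_answer
    False,                           # graph_vector_answer
    True,                            # with_gremlin_template
    0.5,                             # graph_ratio
    "bleu",                          # rerank_method
    False,                           # near_neighbor_first
    "",                              # custom_related_information
    prompt.answer_prompt,            # answer_prompt
    prompt.keywords_extract_prompt,  # keywords_extract_prompt
    gremlin_tmpl_num=1,              # optional, defaults to 2
    gremlin_prompt=prompt.gremlin_generate_prompt,  # optional
)
# ---------------------------------------------------------------------------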
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index a71ed018..03ac9ae0 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -32,6 +32,7 @@
 from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
 from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
 from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_qps
+from hugegraph_llm.config import prompt


 class RAGPipeline:
@@ -126,6 +127,7 @@ def query_graphdb(
         prop_to_match: Optional[str] = None,
         with_gremlin_template: bool = True,
         num_gremlin_generate_example: int = 1,
+        gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
     ):
         """
         Add a graph RAG query operator to the pipeline.
@@ -146,6 +148,7 @@ def query_graphdb(
                 prop_to_match=prop_to_match,
                 with_gremlin_template=with_gremlin_template,
                 num_gremlin_generate_example=num_gremlin_generate_example,
+                gremlin_prompt=gremlin_prompt,
             )
         )
         return self
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
index dfbf085e..95ce59f0 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
@@ -58,8 +58,9 @@ def example_index_query(self, num_examples):
         self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples))
         return self

-    def gremlin_generate_synthesize(self, schema, gremlin_prompt: Optional[str] = None,
-                                    vertices: Optional[List[str]] = None):
+    def gremlin_generate_synthesize(
+        self, schema, gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None
+    ):
         self.operators.append(GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt))
         return self

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 33190b72..a4186c82 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -18,7 +18,7 @@
 import json
 from typing import Any, Dict, Optional, List, Set, Tuple

-from hugegraph_llm.config import huge_settings
+from hugegraph_llm.config import huge_settings, prompt
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
@@ -76,16 +76,17 @@ class GraphRAGQuery:
     def __init__(
-            self,
-            max_deep: int = 2,
-            max_items: int = int(huge_settings.max_items),
-            prop_to_match: Optional[str] = None,
-            llm: Optional[BaseLLM] = None,
-            embedding: Optional[BaseEmbedding] = None,
-            max_v_prop_len: int = 2048,
-            max_e_prop_len: int = 256,
-            with_gremlin_template: bool = True,
-            num_gremlin_generate_example: int = 1
+        self,
+        max_deep: int = 2,
+        max_items: int = int(huge_settings.max_items),
+        prop_to_match: Optional[str] = None,
+        llm: Optional[BaseLLM] = None,
+        embedding: Optional[BaseEmbedding] = None,
+        max_v_prop_len: int = 2048,
+        max_e_prop_len: int = 256,
+        with_gremlin_template: bool = True,
+        num_gremlin_generate_example: int = 1,
+        gremlin_prompt: Optional[str] = None,
     ):
         self._client = PyHugeClient(
             huge_settings.graph_ip,
@@ -108,6 +109,7 @@ def __init__(
         )
         self._num_gremlin_generate_example = num_gremlin_generate_example
         self._with_gremlin_template = with_gremlin_template
+        self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt

     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         self._init_client(context)
@@ -134,12 +136,8 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
         self._gremlin_generator.clear()
         self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
         gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
-            context["simple_schema"],
-            vertices=vertices,
-        ).run(
-            query=query,
-            query_embedding=query_embedding
-        )
+            context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt
+        ).run(query=query, query_embedding=query_embedding)
         if self._with_gremlin_template:
             gremlin = gremlin_response["result"]
         else:
@@ -154,10 +152,9 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
             if context["graph_result"]:
                 context["graph_result_flag"] = 1
                 context["graph_context_head"] = (
-                    f"The following are graph query result "
-                    f"from gremlin query `{gremlin}`.\n"
+                    f"The following are graph query results from the Gremlin query `{gremlin}`.\n"
                 )
-        except Exception as e:    # pylint: disable=broad-except
+        except Exception as e:  # pylint: disable=broad-except
             log.error(e)
             context["graph_result"] = ""
         return context
@@ -285,8 +282,9 @@ def _format_graph_query_result(self, query_paths) -> Tuple[Set[str], List[Set[st

         return subgraph, vertex_degree_list, subgraph_with_degree

-    def _process_path(self, path: Any, use_id_to_match: bool, v_cache: Set[str],
-                      e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, List[str]]:
+    def _process_path(
+        self, path: Any, use_id_to_match: bool, v_cache: Set[str], e_cache: Set[Tuple[str, str, str]]
+    ) -> Tuple[str, List[str]]:
         flat_rel = ""
         raw_flat_rel = path["objects"]
         assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd."
@@ -300,8 +298,7 @@ def _process_path(self, path: Any, use_id_to_match: bool, v_cache: Set[str],
             if i % 2 == 0:
                 # Process each vertex
                 flat_rel, prior_edge_str_len, depth = self._process_vertex(
-                    item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match,
-                    v_cache
+                    item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match, v_cache
                 )
             else:
                 # Process each edge
@@ -311,17 +308,24 @@ def _process_path(self, path: Any, use_id_to_match: bool, v_cache: Set[str],
         return flat_rel, nodes_with_degree

-    def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
-                        prior_edge_str_len: int, depth: int, nodes_with_degree: List[str],
-                        use_id_to_match: bool, v_cache: Set[str]) -> Tuple[str, int, int]:
+    def _process_vertex(
+        self,
+        item: Any,
+        flat_rel: str,
+        node_cache: Set[str],
+        prior_edge_str_len: int,
+        depth: int,
+        nodes_with_degree: List[str],
+        use_id_to_match: bool,
+        v_cache: Set[str],
+    ) -> Tuple[str, int, int]:
         matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match]
         if matched_str in node_cache:
             flat_rel = flat_rel[:-prior_edge_str_len]
             return flat_rel, prior_edge_str_len, depth
         node_cache.add(matched_str)
-        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}"
-                              for k, v in item["props"].items() if v)
+        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}" for k, v in item["props"].items() if v)

         # TODO: we may remove label id or replace with label name
         if matched_str in v_cache:
@@ -335,20 +339,27 @@ def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
             depth += 1
         return flat_rel, prior_edge_str_len, depth

-    def _process_edge(self, item: Any, path_str: str, raw_flat_rel: List[Any], i: int, use_id_to_match: bool,
-                      e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, int]:
-        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}"
-                              for k, v in item["props"].items() if v)
+    def _process_edge(
+        self,
+        item: Any,
+        path_str: str,
+        raw_flat_rel: List[Any],
+        i: int,
+        use_id_to_match: bool,
+        e_cache: Set[Tuple[str, str, str]],
+    ) -> Tuple[str, int]:
+        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}" for k, v in item["props"].items() if v)
         props_str = f"{{{props_str}}}" if props_str else ""

-        prev_matched_str = raw_flat_rel[i - 1]["id"] if use_id_to_match else (
-            raw_flat_rel)[i - 1]["props"][self._prop_to_match]
+        prev_matched_str = (
+            raw_flat_rel[i - 1]["id"] if use_id_to_match else raw_flat_rel[i - 1]["props"][self._prop_to_match]
+        )

-        edge_key = (item['inV'], item['label'], item['outV'])
+        edge_key = (item["inV"], item["label"], item["outV"])
         if edge_key not in e_cache:
             e_cache.add(edge_key)
             edge_label = f"{item['label']}{props_str}"
         else:
-            edge_label = item['label']
+            edge_label = item["label"]
         edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--"
         path_str += edge_str
@@ -365,8 +376,8 @@ def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
         schema = self._get_graph_schema()
         vertex_props_str, edge_props_str = schema.split("\n")[:2]
         # TODO: rename to vertex (also need update in the schema)
-        vertex_props_str = vertex_props_str[len("Vertex properties: "):].strip("[").strip("]")
-        edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]")
+        vertex_props_str = vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]")
+        edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]")
         vertex_labels = self._extract_label_names(vertex_props_str)
         edge_labels = self._extract_label_names(edge_props_str)
         return vertex_labels, edge_labels
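# ---------------------------------------------------------------------------
# Reviewer note (illustration only, not part of the patch): end to end, the
# new gremlin_prompt flows RAGPipeline.query_graphdb -> GraphRAGQuery ->
# GremlinGenerator.gremlin_generate_synthesize, with GraphRAGQuery falling
# back to prompt.gremlin_generate_prompt when the value is empty. A minimal
# pipeline sketch mirroring graph_rag_recall() above (the query string and
# rerank settings are placeholders):

from hugegraph_llm.config import huge_settings, prompt
from hugegraph_llm.operators.graph_rag_task import RAGPipeline

rag = RAGPipeline()
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
    with_gremlin_template=True,
    num_gremlin_generate_example=1,
    gremlin_prompt=prompt.gremlin_generate_prompt,  # omit or pass "" to use the default
).merge_dedup_rerank(
    rerank_method="bleu",
    near_neighbor_first=False,
)
# Executing the pipeline then follows the existing graph_rag_recall() flow.
# ---------------------------------------------------------------------------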