Skip to content

Commit

Permalink
fix(llm): add text in /docs & fix /rag api
Browse files Browse the repository at this point in the history
  • Loading branch information
HJ-Young committed Dec 18, 2024
1 parent 1777045 commit f5992cf
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 85 deletions.
10 changes: 6 additions & 4 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,23 @@ class RAGRequest(BaseModel):
graph_ratio: float = 0.5
rerank_method: Literal["bleu", "reranker"] = "bleu"
near_neighbor_first: bool = False
with_gremlin_template: bool = True
with_gremlin_tmpl: bool = True
custom_priority_info: str = ""
answer_prompt: Optional[str] = None
num_gremlin_generate_example: int = 1
keywords_extract_prompt: Optional[str] = None
gremlin_tmpl_num: int = 1
gremlin_prompt: Optional[str] = ""


class GraphRAGRequest(BaseModel):
query: str = ""
gremlin_tmpl_num: int = 1
with_gremlin_tmpl: bool = True
answer_prompt: Optional[str] = "" # FIXME: read from prompt_settings?
answer_prompt: Optional[str] = "" # FIXME: read from prompt_settings?
rerank_method: Literal["bleu", "reranker"] = "bleu"
near_neighbor_first: bool = False
custom_priority_info: str = ""
gremlin_prompt: Optional[str] = ""


class GraphConfigRequest(BaseModel):
Expand Down Expand Up @@ -80,4 +82,4 @@ class RerankerConfigRequest(BaseModel):

class LogStreamRequest(BaseModel):
admin_token: Optional[str] = None
log_file: Optional[str] = 'llm-server.log'
log_file: Optional[str] = "llm-server.log"
58 changes: 50 additions & 8 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import json
from typing import Literal
from typing import Literal, Optional

from fastapi import status, APIRouter, HTTPException
from fastapi import status, APIRouter, HTTPException, Query

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
Expand All @@ -36,17 +36,20 @@ def graph_rag_recall(
query: str,
gremlin_tmpl_num: int,
with_gremlin_tmpl: bool,
answer_prompt: str, # FIXME: text2gremlin should use it
answer_prompt: str,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
gremlin_prompt: str,
) -> dict:
from hugegraph_llm.operators.graph_rag_task import RAGPipeline

rag = RAGPipeline()

rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
with_gremlin_template=with_gremlin_tmpl, num_gremlin_generate_example=gremlin_tmpl_num
with_gremlin_template=with_gremlin_tmpl,
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
).merge_dedup_rerank(
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
Expand All @@ -60,20 +63,44 @@ def rag_http_api(
router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
def rag_answer_api(
req: RAGRequest,
query_param: str = Query("", description="Query you want to ask"),
raw_answer_param: bool = Query(False, description="Use LLM to generate answer directly"),
vector_only_param: bool = Query(False, description="Use LLM to generate answer with vector"),
graph_only_param: bool = Query(False, description="Use LLM to generate answer with graph RAG only"),
graph_vector_answer_param: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG"),
with_gremlin_tmpl_param: bool = Query(True, description="Use exapmle template in text2gremlin"),
graph_ratio_param: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans"),
rerank_method_param: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results."),
near_neighbor_first_param: bool = Query(False, description="Prioritize near neighbors in the search results."),
custom_priority_info_param: str = Query("", description="Custom information to prioritize certain results."),
answer_prompt_param: Optional[str] = Query(
prompt.answer_prompt, description="Prompt to guide the answer generation."
),
keywords_extract_prompt_param: Optional[str] = Query(
prompt.keywords_extract_prompt, description="Prompt for extracting keywords from the query."
),
gremlin_tmpl_num_param: int = Query(1, description="Number of Gremlin templates to use."),
gremlin_prompt_param: Optional[str] = Query(
prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query. Don't change it casually"
),
):
result = rag_answer_func(
req.query,
req.raw_answer,
req.vector_only,
req.graph_only,
req.graph_vector_answer,
req.with_gremlin_template,
req.with_gremlin_tmpl,
req.graph_ratio,
req.rerank_method,
req.near_neighbor_first,
req.custom_priority_info,
req.answer_prompt or prompt.answer_prompt,
req.keywords_extract_prompt or prompt.keywords_extract_prompt,
req.gremlin_tmpl_num,
req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
return {
key: value
Expand All @@ -82,16 +109,31 @@ def rag_answer_api(req: RAGRequest):
}

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(req: GraphRAGRequest):
def graph_rag_recall_api(
req: GraphRAGRequest,
query_param: str = Query("", description="Query you want to ask"),
with_gremlin_templ_param: bool = Query(True, description="Use exapmle template in text2gremlin"),
rerank_method_param: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results."),
near_neighbor_first_param: bool = Query(False, description="Prioritize near neighbors in the search results."),
custom_priority_info_param: str = Query("", description="Custom information to prioritize certain results."),
answer_prompt_param: Optional[str] = Query(
prompt.answer_prompt, description="Prompt to guide the answer generation."
),
gremlin_tmpl_num_param: int = Query(1, description="Number of Gremlin templates to use."),
gremlin_prompt_param: Optional[str] = Query(
prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query. Don't change it casually"
),
):
try:
result = graph_rag_recall(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
with_gremlin_tmpl=req.with_gremlin_tmpl,
answer_prompt=req.answer_prompt,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)

if isinstance(result, dict):
Expand Down
83 changes: 51 additions & 32 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@


def rag_answer(
text: str,
raw_answer: bool,
vector_only_answer: bool,
graph_only_answer: bool,
graph_vector_answer: bool,
with_gremlin_template: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
text: str,
raw_answer: bool,
vector_only_answer: bool,
graph_only_answer: bool,
graph_vector_answer: bool,
with_gremlin_template: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
gremlin_tmpl_num: Optional[int] = 2,
gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
) -> Tuple:
"""
Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
Expand All @@ -52,15 +54,17 @@ def rag_answer(
5. Run the pipeline and return the results.
"""
should_update_prompt = (
prompt.default_question != text or
prompt.answer_prompt != answer_prompt or
prompt.keywords_extract_prompt != keywords_extract_prompt
prompt.default_question != text
or prompt.answer_prompt != answer_prompt
or prompt.keywords_extract_prompt != keywords_extract_prompt
or prompt.gremlin_generate_prompt != gremlin_prompt
)
if should_update_prompt or prompt.custom_rerank_info != custom_related_information:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
prompt.keywords_extract_prompt = keywords_extract_prompt
prompt.gremlin_generate_prompt = gremlin_prompt
prompt.update_yaml_file()

vector_search = vector_only_answer or graph_vector_answer
Expand All @@ -74,9 +78,18 @@ def rag_answer(
rag.query_vector_index()
if graph_search:
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
huge_settings.graph_name).query_graphdb(with_gremlin_template=with_gremlin_template)
huge_settings.graph_name
).query_graphdb(
with_gremlin_template=with_gremlin_template,
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
)
# TODO: add more user-defined search strategies
rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, )
rag.merge_dedup_rerank(
graph_ratio,
rerank_method,
near_neighbor_first,
)
rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)

try:
Expand Down Expand Up @@ -123,6 +136,7 @@ def create_rag_block():
graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
with gr.Row():
with_gremlin_template_radio = gr.Radio(choices=[True, False], value=True, label="With Gremlin Template")

def toggle_slider(enable):
return gr.update(interactive=enable)

Expand Down Expand Up @@ -164,16 +178,18 @@ def toggle_slider(enable):
near_neighbor_first,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input
keywords_extract_prompt_input,
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
)

gr.Markdown("""## 2. (Batch) Back-testing )
gr.Markdown(
"""## 2. (Batch) Back-testing )
> 1. Download the template file & fill in the questions you want to test.
> 2. Upload the file & click the button to generate answers. (Preview shows the first 40 lines)
> 3. The answer options are the same as the above RAG/Q&A frame
""")
"""
)
tests_df_headers = [
"Question",
"Expected Answer",
Expand Down Expand Up @@ -214,18 +230,19 @@ def change_showing_excel(line_count):
return df

def several_rag_answer(
is_raw_answer: bool,
is_vector_only_answer: bool,
is_graph_only_answer: bool,
is_graph_vector_answer: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
progress=gr.Progress(track_tqdm=True),
answer_max_line_count: int = 1,
is_raw_answer: bool,
is_vector_only_answer: bool,
is_graph_only_answer: bool,
is_graph_vector_answer: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
with_gremlin_template: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
answer_max_line_count: int = 1,
progress=gr.Progress(track_tqdm=True),
):
df = pd.read_excel(questions_path, dtype=str)
total_rows = len(df)
Expand All @@ -240,6 +257,7 @@ def several_rag_answer(
graph_ratio,
rerank_method,
near_neighbor_first,
with_gremlin_template,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
Expand Down Expand Up @@ -273,6 +291,7 @@ def several_rag_answer(
graph_ratio,
rerank_method,
near_neighbor_first,
with_gremlin_template_radio,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
Expand Down
3 changes: 3 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_qps
from hugegraph_llm.config import prompt


class RAGPipeline:
Expand Down Expand Up @@ -126,6 +127,7 @@ def query_graphdb(
prop_to_match: Optional[str] = None,
with_gremlin_template: bool = True,
num_gremlin_generate_example: int = 1,
gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
):
"""
Add a graph RAG query operator to the pipeline.
Expand All @@ -146,6 +148,7 @@ def query_graphdb(
prop_to_match=prop_to_match,
with_gremlin_template=with_gremlin_template,
num_gremlin_generate_example=num_gremlin_generate_example,
gremlin_prompt=gremlin_prompt,
)
)
return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def example_index_query(self, num_examples):
self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples))
return self

def gremlin_generate_synthesize(self, schema, gremlin_prompt: Optional[str] = None,
vertices: Optional[List[str]] = None):
def gremlin_generate_synthesize(
self, schema, gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None
):
self.operators.append(GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt))
return self

Expand Down
Loading

0 comments on commit f5992cf

Please sign in to comment.