Skip to content

Commit

Permalink
fix(llm): fix parameter name in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
HJ-Young committed Dec 18, 2024
1 parent f5992cf commit 19db15b
Showing 1 changed file with 45 additions and 26 deletions.
71 changes: 45 additions & 26 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
from typing import Literal, Optional

from fastapi import status, APIRouter, HTTPException, Query
from fastapi import status, APIRouter, HTTPException, Query, Body

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
Expand Down Expand Up @@ -64,26 +64,34 @@ def rag_http_api(
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(
req: RAGRequest,
query_param: str = Query("", description="Query you want to ask"),
raw_answer_param: bool = Query(False, description="Use LLM to generate answer directly"),
vector_only_param: bool = Query(False, description="Use LLM to generate answer with vector"),
graph_only_param: bool = Query(False, description="Use LLM to generate answer with graph RAG only"),
graph_vector_answer_param: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG"),
with_gremlin_tmpl_param: bool = Query(True, description="Use exapmle template in text2gremlin"),
graph_ratio_param: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans"),
rerank_method_param: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results."),
near_neighbor_first_param: bool = Query(False, description="Prioritize near neighbors in the search results."),
custom_priority_info_param: str = Query("", description="Custom information to prioritize certain results."),
answer_prompt_param: Optional[str] = Query(
req: RAGRequest = Body(...),
query: Optional[str] = Query("", description="Query you want to ask"),
raw_answer: Optional[bool] = Query(False, description="Use LLM to generate answer directly"),
vector_only: Optional[bool] = Query(False, description="Use LLM to generate answer with vector"),
graph_only: Optional[bool] = Query(False, description="Use LLM to generate answer with graph RAG only"),
graph_vector_answer: Optional[bool] = Query(
True, description="Use LLM to generate answer with vector & GraphRAG"
),
with_gremlin_tmpl: Optional[bool] = Query(True, description="Use example template in text2gremlin"),
graph_ratio: Optional[float] = Query(0.5, description="The ratio of GraphRAG ans & vector ans"),
rerank_method: Optional[Literal["bleu", "reranker"]] = Query(
"bleu", description="Method to rerank the results."
),
near_neighbor_first: Optional[bool] = Query(
True, description="Prioritize near neighbors in the search results."
),
custom_priority_info: Optional[str] = Query(
"", description="Custom information to prioritize certain results."
),
answer_prompt: Optional[str] = Query(
prompt.answer_prompt, description="Prompt to guide the answer generation."
),
keywords_extract_prompt_param: Optional[str] = Query(
prompt.keywords_extract_prompt, description="Prompt for extracting keywords from the query."
keywords_extract_prompt: Optional[str] = Query(
prompt.keywords_extract_prompt, description="Prompt for extracting keywords from query."
),
gremlin_tmpl_num_param: int = Query(1, description="Number of Gremlin templates to use."),
gremlin_prompt_param: Optional[str] = Query(
prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query. Don't change it casually"
gremlin_tmpl_num: Optional[int] = Query(1, description="Number of Gremlin templates to use."),
gremlin_prompt: Optional[str] = Query(
prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query."
),
):
result = rag_answer_func(
Expand All @@ -107,20 +115,31 @@ def rag_answer_api(
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
}
return {
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
}

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(
req: GraphRAGRequest,
query_param: str = Query("", description="Query you want to ask"),
with_gremlin_templ_param: bool = Query(True, description="Use exapmle template in text2gremlin"),
rerank_method_param: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results."),
near_neighbor_first_param: bool = Query(False, description="Prioritize near neighbors in the search results."),
custom_priority_info_param: str = Query("", description="Custom information to prioritize certain results."),
answer_prompt_param: Optional[str] = Query(
query: Optional[str] = Query("", description="Query you want to ask"),
with_gremlin_tmpl: Optional[bool] = Query(True, description="Use example template in text2gremlin"),
rerank_method: Optional[Literal["bleu", "reranker"]] = Query(
"bleu", description="Method to rerank the results."
),
near_neighbor_first: Optional[bool] = Query(
False, description="Prioritize near neighbors in the search results."
),
custom_priority_info: Optional[str] = Query(
"", description="Custom information to prioritize certain results."
),
answer_prompt: Optional[str] = Query(
prompt.answer_prompt, description="Prompt to guide the answer generation."
),
gremlin_tmpl_num_param: int = Query(1, description="Number of Gremlin templates to use."),
gremlin_prompt_param: Optional[str] = Query(
gremlin_tmpl_num: Optional[int] = Query(1, description="Number of Gremlin templates to use."),
gremlin_prompt: Optional[str] = Query(
prompt.gremlin_generate_prompt, description="Prompt for the Gremlin query. Don't change it casually"
),
):
Expand Down

0 comments on commit 19db15b

Please sign in to comment.