Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(llm): remove enable_gql logic in api & rag block #148

Merged
merged 4 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class RAGRequest(BaseModel):
graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans")
rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin")
custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
keywords_extract_prompt: Optional[str] = Query(
Expand All @@ -49,8 +48,9 @@ class RAGRequest(BaseModel):
# TODO: import the default value of prompt.* dynamically
class GraphRAGRequest(BaseModel):
query: str = Query("", description="Query you want to ask")
gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.")
with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin")
gremlin_tmpl_num: int = Query(
1, description="Number of Gremlin templates to use. If num <=0 means template is not provided"
)
rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
Expand Down
10 changes: 5 additions & 5 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def rag_answer_api(req: RAGRequest):
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
with_gremlin_template=req.with_gremlin_tmpl,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
Expand All @@ -62,9 +61,11 @@ def rag_answer_api(req: RAGRequest):
# TODO: we need more info in the response for users to understand the query logic
return {
"query": req.query,
**{key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)}
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
},
}

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
Expand All @@ -73,7 +74,6 @@ def graph_rag_recall_api(req: GraphRAGRequest):
result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
with_gremlin_tmpl=req.with_gremlin_tmpl,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
Expand Down
15 changes: 6 additions & 9 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def rag_answer(
vector_only_answer: bool,
graph_only_answer: bool,
graph_vector_answer: bool,
with_gremlin_template: bool,
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
Expand Down Expand Up @@ -80,7 +79,6 @@ def rag_answer(
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
huge_settings.graph_name
).query_graphdb(
with_gremlin_template=with_gremlin_template,
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
)
Expand Down Expand Up @@ -125,7 +123,10 @@ def create_rag_block():
value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7
)
keywords_extract_prompt_input = gr.Textbox(
value=prompt.keywords_extract_prompt, label="Keywords Extraction Prompt", show_copy_button=True, lines=7
value=prompt.keywords_extract_prompt,
label="Keywords Extraction Prompt",
show_copy_button=True,
lines=7,
)
with gr.Column(scale=1):
with gr.Row():
Expand All @@ -134,8 +135,6 @@ def create_rag_block():
with gr.Row():
graph_only_radio = gr.Radio(choices=[True, False], value=True, label="Graph-only Answer")
graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
with gr.Row():
Copy link
Member

@imbajin imbajin Dec 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we should change it to a number radio instead of removing it directly?

with_gremlin_template_radio = gr.Radio(choices=[True, False], value=True, label="With Gremlin Template")

def toggle_slider(enable):
return gr.update(interactive=enable)
Expand All @@ -148,6 +147,7 @@ def toggle_slider(enable):
value="reranker" if online_rerank else "bleu",
label="Rerank method",
)
example_num = gr.Number(value=2, label="Template Num (0 to disable it) ", precision=0)
graph_ratio = gr.Slider(0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False)

graph_vector_radio.change(
Expand All @@ -172,13 +172,13 @@ def toggle_slider(enable):
vector_only_radio,
graph_only_radio,
graph_vector_radio,
with_gremlin_template_radio,
graph_ratio,
rerank_method,
near_neighbor_first,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
example_num,
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
)
Expand Down Expand Up @@ -237,7 +237,6 @@ def several_rag_answer(
graph_ratio: float,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
with_gremlin_template: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
Expand All @@ -257,7 +256,6 @@ def several_rag_answer(
graph_ratio,
rerank_method,
near_neighbor_first,
with_gremlin_template,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
Expand Down Expand Up @@ -291,7 +289,6 @@ def several_rag_answer(
graph_ratio,
rerank_method,
near_neighbor_first,
with_gremlin_template_radio,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@

def store_schema(schema, question, gremlin_prompt):
if (
prompt.text2gql_graph_schema != schema or
prompt.default_question != question or
prompt.gremlin_generate_prompt != gremlin_prompt
prompt.text2gql_graph_schema != schema
or prompt.default_question != question
or prompt.gremlin_generate_prompt != gremlin_prompt
):
prompt.text2gql_graph_schema = schema
prompt.default_question = question
Expand Down Expand Up @@ -90,7 +90,8 @@ def gremlin_generate(
updated_schema = sm.simple_schema(schema) if short_schema else schema
store_schema(str(updated_schema), inp, gremlin_prompt)
context = (
generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema, gremlin_prompt)
generator.example_index_query(example_num)
.gremlin_generate_synthesize(updated_schema, gremlin_prompt)
.run(query=inp)
)
try:
Expand Down Expand Up @@ -183,7 +184,6 @@ def create_text2gremlin_block() -> Tuple:
def graph_rag_recall(
query: str,
gremlin_tmpl_num: int,
with_gremlin_tmpl: bool,
rerank_method: Literal["bleu", "reranker"],
near_neighbor_first: bool,
custom_related_information: str,
Expand All @@ -193,7 +193,6 @@ def graph_rag_recall(
rag = RAGPipeline()

rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
with_gremlin_template=with_gremlin_tmpl,
num_gremlin_generate_example=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
).merge_dedup_rerank(
Expand Down
4 changes: 1 addition & 3 deletions hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ def query_graphdb(
max_v_prop_len: int = 2048,
max_e_prop_len: int = 256,
prop_to_match: Optional[str] = None,
with_gremlin_template: bool = True,
num_gremlin_generate_example: int = 1,
num_gremlin_generate_example: Optional[int] = 1,
gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
):
"""
Expand All @@ -146,7 +145,6 @@ def query_graphdb(
max_v_prop_len=max_v_prop_len,
max_e_prop_len=max_e_prop_len,
prop_to_match=prop_to_match,
with_gremlin_template=with_gremlin_template,
num_gremlin_generate_example=num_gremlin_generate_example,
gremlin_prompt=gremlin_prompt,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,9 @@ def __init__(
prop_to_match: Optional[str] = None,
llm: Optional[BaseLLM] = None,
embedding: Optional[BaseEmbedding] = None,
max_v_prop_len: int = 2048,
max_e_prop_len: int = 256,
with_gremlin_template: bool = True,
num_gremlin_generate_example: int = 1,
max_v_prop_len: Optional[int] = 2048,
max_e_prop_len: Optional[int] = 256,
num_gremlin_generate_example: Optional[int] = 1,
gremlin_prompt: Optional[str] = None,
):
self._client = PyHugeClient(
Expand All @@ -108,7 +107,6 @@ def __init__(
embedding=embedding,
)
self._num_gremlin_generate_example = num_gremlin_generate_example
self._with_gremlin_template = with_gremlin_template
self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -138,7 +136,7 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt
).run(query=query, query_embedding=query_embedding)
if self._with_gremlin_template:
if self._num_gremlin_generate_example > 0:
gremlin = gremlin_response["result"]
else:
gremlin = gremlin_response["raw_result"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ def __init__(self, embedding: BaseEmbedding = None, num_examples: int = 1):
self.vector_index = VectorIndex.from_index_file(self.index_dir)

def _ensure_index_exists(self):
if not (os.path.exists(os.path.join(self.index_dir, "index.faiss"))
and os.path.exists(os.path.join(self.index_dir, "properties.pkl"))):
if not (
os.path.exists(os.path.join(self.index_dir, "index.faiss"))
and os.path.exists(os.path.join(self.index_dir, "properties.pkl"))
):
log.warning("No gremlin example index found, will generate one.")
self._build_default_example_index()

def _get_match_result(self, context: Dict[str, Any], query: str) -> List[Dict[str, Any]]:
if self.num_examples == 0:
if self.num_examples <= 0:
return []

query_embedding = context.get("query_embedding")
Expand All @@ -53,8 +55,7 @@ def _get_match_result(self, context: Dict[str, Any], query: str) -> List[Dict[st
return self.vector_index.search(query_embedding, self.num_examples, dis_threshold=1.8)

def _build_default_example_index(self):
properties = pd.read_csv(os.path.join(resource_path, "demo",
"text2gremlin.csv")).to_dict(orient="records")
properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records")
embeddings = [self.embedding.get_text_embedding(row["query"]) for row in tqdm(properties)]
vector_index = VectorIndex(len(embeddings[0]))
vector_index.add(embeddings, properties)
Expand Down
Loading