
feat(llm): handle 'island nodes' extraction in 2-step graph queries and add asynchronous methods to the four types of generation functions in the rag web demo. #58

Merged · 13 commits · Aug 12, 2024
92 changes: 50 additions & 42 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -50,10 +50,12 @@ def convert_bool_str(string):
raise gr.Error(f"Invalid boolean string: {string}")


# TODO: enhance/distinguish the "graph_rag" name to avoid confusion
def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
graph_only_answer: str, graph_vector_answer):
vector_search = convert_bool_str(vector_only_answer) or convert_bool_str(graph_vector_answer)
graph_search = convert_bool_str(graph_only_answer) or convert_bool_str(graph_vector_answer)

if raw_answer == "false" and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
@@ -68,6 +70,7 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
graph_only_answer=convert_bool_str(graph_only_answer),
graph_vector_answer=convert_bool_str(graph_vector_answer)
).run(verbose=True, query=text)

try:
context = searcher.run(verbose=True, query=text)
return (
@@ -76,9 +79,12 @@ def graph_rag(text: str, raw_answer: str, vector_only_answer: str,
context.get("graph_only_answer", ""),
context.get("graph_vector_answer", "")
)
except Exception as e: # pylint: disable=broad-exception-caught
except ValueError as e:
log.error(e)
raise gr.Error(str(e))
except Exception as e: # pylint: disable=broad-exception-caught
log.error(e)
raise gr.Error(f"An unexpected error occurred: {str(e)}")


def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-many-branches
@@ -93,10 +99,11 @@ def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-many-branches
text += para.text
text += "\n"
elif full_path.endswith(".pdf"):
raise gr.Error("PDF will be supported later!")
raise gr.Error("PDF will be supported later! Try to upload text/docx now")
else:
raise gr.Error("Please input txt or docx file.")
builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())

if build_mode != "Rebuild vertex index":
if schema:
try:
@@ -108,19 +115,22 @@ def build_kg(file, schema, example_prompt, build_mode):  # pylint: disable=too-many-branches
else:
return "ERROR: please input schema."
builder.chunk_split(text, "paragraph", "zh")
if build_mode == "Rebuild vertex index":

# TODO: avoid hardcoding the "build_mode" strings (use var/constant instead)
if build_mode == "Rebuild Vector":
builder.fetch_graph_data()
else:
builder.extract_info(example_prompt, "property_graph")
if build_mode != "Test":
if build_mode in ("Clear and import", "Rebuild vertex index"):
# "Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"
if build_mode != "Test Mode":
if build_mode in ("Clear and Import", "Rebuild Vector"):
clean_vector_index()
builder.build_vector_index()
if build_mode == "Clear and import":
if build_mode == "Clear and Import":
clean_hg_data()
if build_mode in ("Clear and import", "Import"):
if build_mode in ("Clear and Import", "Import Mode"):
builder.commit_to_hugegraph()
if build_mode != "Test":
if build_mode != "Test Mode":
builder.build_vertex_id_semantic_index()
log.debug(builder.operators)
try:
@@ -319,22 +329,18 @@ def apply_embedding_configuration(arg1, arg2, arg3):


gr.Markdown(
"""## 1. build knowledge graph
"""## 1. Build vector/graph RAG (💡)
- Document: Input a document file, which should be TXT or DOCX.
- Schema: Accepts two types of text as below:
- User-defined JSON format Schema.
- Specify the name of the HugeGraph graph instance, and it will
automatically extract the schema of the graph.
- Specify the name of the HugeGraph graph instance, and the schema will be fetched from it automatically.
- Info extract head: The head of the prompt used for information extraction.
- Build mode:
- Test: Only extract vertices and edges from file without building vector index or
importing into HugeGraph.
- Clear and Import: Clear the vector index and data of HugeGraph and then extract and
import new data.
- Import: Extract the data and append it to HugeGraph and vector index without clearing
anything.
- Rebuild vertex index: Do not clear the HugeGraph data, but only clear vector index
and build new one.
- Test Mode: Only extract vertices and edges from the file into memory (without building the vector index or
writing data into HugeGraph)
- Import Mode: Extract the data and append it to HugeGraph & the vector index (without clearing any existing data)
- Clear and Import: Clear all existing RAG data (vector + graph), then rebuild it from the current input
- Rebuild Vector: Only rebuild the vector index (keep the graph data intact)
"""
)

@@ -380,10 +386,9 @@ def apply_embedding_configuration(arg1, arg2, arg3):
info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT,
label="Info extract head")
with gr.Column():
mode = gr.Radio(choices=["Test", "Clear and import", "Import",
"Rebuild vertex index"],
value="Test", label="Build mode")
btn = gr.Button("Build knowledge graph")
mode = gr.Radio(choices=["Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"],
value="Test Mode", label="Build mode")
btn = gr.Button("Build Vector/Graph RAG")
with gr.Row():
out = gr.Textbox(label="Output", show_copy_button=True)
btn.click( # pylint: disable=no-member
@@ -392,40 +397,43 @@ def apply_embedding_configuration(arg1, arg2, arg3):
outputs=out
)

gr.Markdown("""## 2. Retrieval augmented generation by hugegraph""")
gr.Markdown("""## 2. RAG with HugeGraph 📖""")
with gr.Row():
with gr.Column(scale=2):
inp = gr.Textbox(value="Tell me about Sarah.", label="Question")
raw_out = gr.Textbox(label="Raw LLM Answer", show_copy_button=True)
vector_only_out = gr.Textbox(label="Vector-only answer", show_copy_button=True)
graph_only_out = gr.Textbox(label="Graph-only answer", show_copy_button=True)
graph_vector_out = gr.Textbox(label="Graph-Vector answer", show_copy_button=True)
inp = gr.Textbox(value="Tell me about Sarah.", label="Question", show_copy_button=True)
raw_out = gr.Textbox(label="Basic LLM Answer", show_copy_button=True)
vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True)
graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True)
graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True)
with gr.Column(scale=1):
raw_radio = gr.Radio(choices=["true", "false"], value="false",
label="Raw LLM answer")
label="Basic LLM Answer")
vector_only_radio = gr.Radio(choices=["true", "false"], value="true",
label="Vector-only answer")
label="Vector-only Answer")
graph_only_radio = gr.Radio(choices=["true", "false"], value="false",
label="Graph-only answer")
label="Graph-only Answer")
graph_vector_radio = gr.Radio(choices=["true", "false"], value="false",
label="Graph-Vector answer")
btn = gr.Button("Retrieval augmented generation")
btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member
label="Graph-Vector Answer")
btn = gr.Button("Answer Question")
btn.click(fn=graph_rag, inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member
graph_vector_radio],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out])

gr.Markdown("""## 3. Others """)
gr.Markdown("""## 3. Others (🚧) """)
with gr.Row():
inp = []
with gr.Column():
inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True)
format = gr.Checkbox(label="Format JSON", value=True)
out = gr.Textbox(label="Output", show_copy_button=True)
btn = gr.Button("Initialize HugeGraph test data")
btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member
btn = gr.Button("Run gremlin query on HugeGraph")
btn.click(fn=run_gremlin_query, inputs=[inp, format], outputs=out) # pylint: disable=no-member

with gr.Row():
inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query")
inp = []
out = gr.Textbox(label="Output", show_copy_button=True)
btn = gr.Button("Run gremlin query on HugeGraph")
btn.click(fn=run_gremlin_query, inputs=inp, outputs=out) # pylint: disable=no-member
btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)")
btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member

app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
# Note: set reload to False in production environment
uvicorn.run(app, host=args.host, port=args.port)
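The TODO above notes that the "build_mode" strings are hardcoded in several places. A minimal sketch of how they could be centralized, assuming a simple enum (the BuildMode name and its placement are illustrative, not part of this PR):

from enum import Enum


class BuildMode(Enum):
    # Values match the strings used by the Gradio radio widget above.
    TEST = "Test Mode"
    IMPORT = "Import Mode"
    CLEAR_AND_IMPORT = "Clear and Import"
    REBUILD_VECTOR = "Rebuild Vector"


# Usage sketch: compare against enum values instead of raw literals, e.g.
# if build_mode == BuildMode.REBUILD_VECTOR.value:
#     builder.fetch_graph_data()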
1 change: 1 addition & 0 deletions hugegraph-llm/src/hugegraph_llm/enums/id_strategy.py
@@ -19,6 +19,7 @@
from enum import Enum


# Note: we don't support the "UUID" strategy for now
class IdStrategy(Enum):
AUTOMATIC = "AUTOMATIC"
CUSTOMIZE_NUMBER = "CUSTOMIZE_NUMBER"
4 changes: 2 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/indices/graph_index.py
@@ -36,12 +36,12 @@ def __init__(
def clear_graph(self):
self.client.gremlin().exec("g.V().drop()")

# TODO: replace triples with a more specific graph element type & implement it
def add_triples(self, triples: list):
# TODO
pass

# TODO: replace triples with a more specific graph element type & implement it
def search_triples(self, max_deep: int = 2):
# TODO
pass

def execute_gremlin_query(self, query: str):
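Both methods above remain TODOs in this PR. A possible sketch of add_triples, assuming each triple is a (head, relation, tail) tuple of names and reusing the gremlin exec channel from clear_graph; the 'entity' label and 'name' property are hypothetical:

def add_triples(self, triples: list):
    # Sketch only: insert each (head, relation, tail) triple via a Gremlin
    # script. Real code should use parameter bindings instead of f-strings
    # and upsert semantics to avoid duplicate vertices.
    for head, relation, tail in triples:
        self.client.gremlin().exec(
            f"h = g.addV('entity').property('name', '{head}').next(); "
            f"t = g.addV('entity').property('name', '{tail}').next(); "
            f"g.addE('{relation}').from(h).to(t).iterate()"
        )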
7 changes: 7 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
@@ -60,6 +60,13 @@ def get_text_embedding(
) -> List[float]:
"""Comment"""

@abstractmethod
async def async_get_text_embedding(
self,
text: str
) -> List[float]:
"""Comment"""

@staticmethod
def similarity(
embedding1: Union[List[float], np.ndarray],
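A short usage sketch for the new async method: callers can embed many chunks concurrently instead of serially (the embed_batch helper is illustrative, not part of this PR):

import asyncio
from typing import List


async def embed_batch(embedding, texts: List[str]) -> List[List[float]]:
    # Fire one embedding request per text and await them all concurrently.
    return await asyncio.gather(
        *(embedding.async_get_text_embedding(t) for t in texts)
    )


# vectors = asyncio.run(embed_batch(Embeddings().get_embedding(), chunks))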
9 changes: 9 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
@@ -32,6 +32,7 @@ def __init__(
):
self.model = model
self.client = ollama.Client(host=f"http://{host}:{port}", **kwargs)
self.async_client = ollama.AsyncClient(host=f"http://{host}:{port}", **kwargs)
self.embedding_dimension = None

def get_text_embedding(
@@ -40,3 +41,11 @@ def get_text_embedding(
) -> List[float]:
"""Comment"""
return list(self.client.embeddings(model=self.model, prompt=text)["embedding"])

async def async_get_text_embedding(
self,
text: str
) -> List[float]:
"""Comment"""
response = await self.async_client.embeddings(model=self.model, prompt=text)
return list(response["embedding"])
5 changes: 5 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
@@ -38,3 +38,8 @@ def get_text_embedding(self, text: str) -> List[float]:
"""Comment"""
response = self.client.create(input=text, model=self.embedding_model_name)
return response.data[0].embedding

async def async_get_text_embedding(self, text: str) -> List[float]:
"""Comment"""
response = await self.client.acreate(input=text, model=self.embedding_model_name)
return response.data[0].embedding
8 changes: 8 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
@@ -49,3 +49,11 @@ def get_text_embedding(self, text: str) -> List[float]:
texts=[text]
)
return response["body"]["data"][0]["embedding"]

async def async_get_text_embedding(self, text: str) -> List[float]:
""" Usage refer: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn"""
response = await self.client.ado(
model=self.embedding_model_name,
texts=[text]
)
return response["body"]["data"][0]["embedding"]
8 changes: 8 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/llms/base.py
@@ -31,6 +31,14 @@ def generate(
) -> str:
"""Comment"""

@abstractmethod
async def agenerate(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
"""Comment"""

@abstractmethod
def generate_streaming(
self,
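A minimal usage sketch for agenerate, assuming the OllamaClient added below in this PR (the model name is illustrative):

import asyncio

from hugegraph_llm.models.llms.ollama import OllamaClient


async def main():
    llm = OllamaClient(model="llama3")
    # agenerate mirrors generate(): same messages/prompt arguments and the
    # same string return value, but awaitable.
    print(await llm.agenerate(prompt="What is a property graph?"))


asyncio.run(main())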
21 changes: 21 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -29,6 +29,7 @@ class OllamaClient(BaseLLM):
def __init__(self, model: str, host: str = "127.0.0.1", port: int = 11434, **kwargs):
self.model = model
self.client = ollama.Client(host=f"http://{host}:{port}", **kwargs)
self.async_client = ollama.AsyncClient(host=f"http://{host}:{port}", **kwargs)

@retry(tries=3, delay=1)
def generate(
@@ -50,6 +51,26 @@ def generate(
print(f"Retrying LLM call {e}")
raise e

@retry(tries=3, delay=1)
async def agenerate(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
"""Comment"""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
try:
response = await self.async_client.chat(
model=self.model,
messages=messages,
)
return response["message"]["content"]
except Exception as e:
print(f"Retrying LLM call {e}")
raise e

def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
30 changes: 30 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -73,6 +73,36 @@ def generate(
log.error("Retrying LLM call %s", e)
raise e

@retry(tries=3, delay=1)
async def agenerate(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
"""Generate a response to the query messages/prompt."""
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
try:
completions = await openai.ChatCompletion.acreate(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
messages=messages,
)
return completions.choices[0].message.content
# catch context length / do not retry
except openai.error.InvalidRequestError as e:
log.critical("Fatal: %s", e)
return str(f"Error: {e}")
# catch authorization errors / do not retry
except openai.error.AuthenticationError:
log.critical("The provided OpenAI API key is invalid")
return "Error: The provided OpenAI API key is invalid"
except Exception as e:
log.error("Retrying LLM call %s", e)
raise e

def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
17 changes: 17 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
@@ -49,6 +49,23 @@ def generate(
)
return response.body["result"]

@retry(tries=3, delay=1)
async def agenerate(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]

response = await self.chat_comp.ado(model=self.chat_model, messages=messages)
if response.code != 200:
raise Exception(
f"Request failed with code {response.code}, message: {response.body['error_msg']}"
)
return response.body["result"]

def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
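With agenerate implemented on all three clients, the demo's four answer modes can be produced concurrently rather than one after another; a sketch under the assumption that prompt construction happens elsewhere:

import asyncio


async def answer_all(llm, prompts: list):
    # Run the basic, vector-only, graph-only and graph-vector prompts
    # concurrently; results come back in the same order as the prompts.
    return await asyncio.gather(*(llm.agenerate(prompt=p) for p in prompts))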
@@ -31,7 +31,8 @@ def run(self, schema=None) -> Any:  # pylint: disable=too-many-branches
raise ValueError("Input data is not a dictionary.")
if "vertexlabels" not in schema or "edgelabels" not in schema:
raise ValueError("Input data does not contain 'vertexlabels' or 'edgelabels'.")
if not isinstance(schema["vertexlabels"], list) or not isinstance(schema["edgelabels"], list):
if not isinstance(schema["vertexlabels"], list) or not isinstance(schema["edgelabels"],
list):
raise ValueError("'vertexlabels' or 'edgelabels' in input data is not a list.")
for vertex in schema["vertexlabels"]:
if not isinstance(vertex, dict):
@@ -35,6 +35,7 @@ def __init__(self, embedding: BaseEmbedding, topk: int = 10):
self.topk = topk

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
# TODO: exact > fuzzy; vertex > 1-depth-neighbour > 2-depth-neighbour; priority vertices
query = context.get("query")

vector_result = context.get("vector_result", [])
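The TODO above sketches a ranking policy for semantic id matches: exact before fuzzy, and nearer vertices before deeper neighbours. One illustrative way to express it as a sort key (the match fields are hypothetical):

def rank_key(match: dict):
    # Lower tuples sort first: exact matches beat fuzzy ones, and a smaller
    # graph distance (0 = the matched vertex itself) beats a larger one.
    return (0 if match["exact"] else 1, match["depth"])


# matches.sort(key=rank_key)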