Commit

clean the func & tiny fix
imbajin committed Aug 17, 2024
1 parent 5a0dff6 commit 07e552d
Showing 3 changed files with 29 additions and 53 deletions.
16 changes: 0 additions & 16 deletions hugegraph-llm/src/hugegraph_llm/api/__init__ .py

This file was deleted.

18 changes: 9 additions & 9 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -54,8 +54,7 @@ class LLMConfigRequest(BaseModel):
     port: str = None
 
 
-def rag_web_http_api(app: FastAPI, graph_rag_func, apply_graph_configuration_func,
-                     apply_llm_configuration_func, apply_embedding_configuration_func):
+def rag_http_api(app: FastAPI, graph_rag_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf):
     @app.post("/rag")
     def graph_rag_api(req: RAGRequest):
         result = graph_rag_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector)
@@ -65,7 +64,7 @@ def graph_rag_api(req: RAGRequest):
     @app.post("/graph/config")
     def graph_config_api(req: GraphConfigRequest):
         # Accept status code
-        status_code = apply_graph_configuration_func(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
+        status_code = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
 
         if status_code == -1:
             return {"message": "Unsupported HTTP method"}
@@ -80,11 +79,12 @@ def llm_config_api(req: LLMConfigRequest):
         settings.llm_type = req.llm_type
 
         if req.llm_type == "openai":
-            status_code = apply_llm_configuration_func(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
+            status_code = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens,
+                                         origin_call="http")
         elif req.llm_type == "qianfan_wenxin":
-            status_code = apply_llm_configuration_func(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
+            status_code = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
         else:
-            status_code = apply_llm_configuration_func(req.host, req.port, req.language_model, None, origin_call="http")
+            status_code = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
 
         if status_code == -1:
             return {"message": "Unsupported HTTP method"}
@@ -99,11 +99,11 @@ def embedding_config_api(req: LLMConfigRequest):
         settings.embedding_type = req.llm_type
 
         if req.llm_type == "openai":
-            status_code = apply_embedding_configuration_func(req.api_key, req.api_base, req.language_model, origin_call="http")
+            status_code = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
         elif req.llm_type == "qianfan_wenxin":
-            status_code = apply_embedding_configuration_func(req.api_key, req.api_base, None, origin_call="http")
+            status_code = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
         else:
-            status_code = apply_embedding_configuration_func(req.host, req.port, req.language_model, origin_call="http")
+            status_code = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
 
         if status_code == -1:
             return {"message": "Unsupported HTTP method"}
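For orientation, a minimal sketch of how a client could exercise the renamed routes. The JSON field names follow the `RAGRequest` and `GraphConfigRequest` attributes referenced in the handlers above; the base URL, port, and all credential/query values are placeholder assumptions, since the server setup lives in rag_web_demo.py.

```python
# Hypothetical smoke test for the renamed endpoints, using the `requests` library.
# Assumes the demo app from rag_web_demo.py is running; adjust BASE_URL to your deployment.
import requests

BASE_URL = "http://127.0.0.1:8001"  # placeholder address

# POST /rag -- fields mirror RAGRequest as consumed by graph_rag_api()
answer = requests.post(f"{BASE_URL}/rag", json={
    "query": "your question here",
    "raw_llm": True,        # plain LLM answer
    "vector_only": False,   # vector-index-only answer
    "graph_only": False,    # graph-only answer
    "graph_vector": False,  # combined graph + vector answer
}).json()
print(answer)

# POST /graph/config -- fields mirror GraphConfigRequest
result = requests.post(f"{BASE_URL}/graph/config", json={
    "ip": "127.0.0.1",
    "port": "8080",
    "name": "hugegraph",
    "user": "admin",
    "pwd": "admin",  # placeholder credentials
    "gs": "",        # graph space, if any
}).json()
print(result)  # apply_graph_conf() returning -1 maps to "Unsupported HTTP method"
```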
48 changes: 20 additions & 28 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -41,20 +41,11 @@
 from hugegraph_llm.utils.log import log
 from hugegraph_llm.utils.hugegraph_utils import get_hg_client
 from hugegraph_llm.utils.vector_index_utils import clean_vector_index
-from hugegraph_llm.api.rag_api import rag_web_http_api
+from hugegraph_llm.api.rag_api import rag_http_api
 
 
-def convert_bool_str(string):
-    if string == "true":
-        return True
-    if string == "false":
-        return False
-    raise gr.Error(f"Invalid boolean string: {string}")
-
-
 # TODO: enhance/distinguish the "graph_rag" name to avoid confusion
-def graph_rag(text: str, raw_answer: bool, vector_only_answer: bool,
-              graph_only_answer: bool, graph_vector_answer: bool):
+def rag_answer(text: str, raw_answer: bool, vector_only_answer: bool,
+               graph_only_answer: bool, graph_vector_answer: bool) -> tuple:
     vector_search = vector_only_answer or graph_vector_answer
     graph_search = graph_only_answer or graph_vector_answer
 
@@ -89,7 +80,7 @@ def graph_rag(text: str, raw_answer: bool, vector_only_answer: bool,
         raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
 
-def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-many-branches
+def build_kg(file, schema, example_prompt, build_mode) -> str: # pylint: disable=too-many-branches
     full_path = file.name
     if full_path.endswith(".txt"):
         with open(full_path, "r", encoding="utf-8") as f:
@@ -142,8 +133,9 @@ def build_kg(file, schema, example_prompt, build_mode): # pylint: disable=too-m
         log.error(e)
         raise gr.Error(str(e))
 
 
-def test_api_connection(url, method="GET", headers=None, body=None, auth=None, origin_call=None):
+# todo: origin_call was created to stave off problems with gr.error that needed to be fixed
+def test_api_connection(url, method="GET", headers=None, body=None, auth=None, origin_call=None) -> int:
     # TODO: use fastapi.request / starlette instead? (Also add a try-catch here)
     response = None
     log.debug("Request URL: %s", url)
@@ -172,7 +164,7 @@ def test_api_connection(url, method="GET", headers=None, body=None, auth=None, o
     return response.status_code
 
 
-def apply_embedding_configuration(arg1, arg2, arg3, origin_call=None):
+def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
     # Because of ollama, the qianfan_wenxin model is missing the test connect procedure,
     # so it defaults to 200 so that there is no return value problem
     status_code = 200
@@ -184,19 +176,19 @@ def apply_embedding_configuration(arg1, arg2, arg3, origin_call=None):
         test_url = settings.openai_api_base + "/models"
         headers = {"Authorization": f"Bearer {arg1}"}
         status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call)
-    elif embedding_option == "qianfan_wenxin":
-        settings.qianfan_access_token = arg1
-        settings.qianfan_embed_url = arg2
     elif embedding_option == "ollama":
         settings.ollama_host = arg1
         settings.ollama_port = int(arg2)
         settings.ollama_embedding_model = arg3
+    elif embedding_option == "qianfan_wenxin":
+        settings.qianfan_access_token = arg1
+        settings.qianfan_embed_url = arg2
     settings.update_env()
     gr.Info("Configured!")
     return status_code
 
 
-def apply_graph_configuration(ip, port, name, user, pwd, gs, origin_call=None):
+def apply_graph_config(ip, port, name, user, pwd, gs, origin_call=None) -> int:
     settings.graph_ip = ip
     settings.graph_port = int(port)
     settings.graph_name = name
@@ -217,7 +209,7 @@ def apply_graph_configuration(ip, port, name, user, pwd, gs, origin_call=None):
 
 # Different llm models have different parameters,
 # so no meaningful argument names are given here
-def apply_llm_configuration(arg1, arg2, arg3, arg4, origin_call=None):
+def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int:
     llm_option = settings.llm_type
     # Because of ollama, the qianfan_wenxin model is missing the test connect procedure,
     # so it defaults to 200 so that there is no return value problem
@@ -245,7 +237,7 @@ def apply_llm_configuration(arg1, arg2, arg3, arg4, origin_call=None):
     return status_code
 
 
-def create_hugegraph_llm_interface():
+def init_rag_ui() -> gr.Interface:
     with gr.Blocks() as hugegraph_llm:
         gr.Markdown(
             """# HugeGraph LLM RAG Demo
@@ -264,7 +256,7 @@ def create_hugegraph_llm_interface():
         ]
         graph_config_button = gr.Button("apply configuration")
 
-        graph_config_button.click(apply_graph_configuration, inputs=graph_config_input) # pylint: disable=no-member
+        graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member
 
         gr.Markdown("2. Set up the LLM.")
         llm_dropdown = gr.Dropdown(
@@ -307,7 +299,7 @@ def llm_settings(llm_type):
                 llm_config_input = []
             llm_config_button = gr.Button("apply configuration")
 
-            llm_config_button.click(apply_llm_configuration, inputs=llm_config_input) # pylint: disable=no-member
+            llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member
 
         gr.Markdown("3. Set up the Embedding.")
         embedding_dropdown = gr.Dropdown(
@@ -349,11 +341,11 @@ def embedding_settings(embedding_type):
 
             # Call the separate apply_embedding_configuration function here
            embedding_config_button.click(
-                lambda arg1, arg2, arg3: apply_embedding_configuration(arg1, arg2, arg3),
+                lambda arg1, arg2, arg3: apply_embedding_config(arg1, arg2, arg3),
                 inputs=embedding_config_input
             )
 
-            embedding_config_button.click(apply_embedding_configuration, # pylint: disable=no-member
+            embedding_config_button.click(apply_embedding_config, # pylint: disable=no-member
                                           inputs=embedding_config_input)
 
         gr.Markdown(
@@ -443,7 +435,7 @@ def embedding_settings(embedding_type):
             graph_vector_radio = gr.Radio(choices=[True, False], value=False,
                                           label="Graph-Vector Answer")
             btn = gr.Button("Answer Question")
-            btn.click(fn=graph_rag,
+            btn.click(fn=rag_answer,
                       inputs=[inp, raw_radio, vector_only_radio, graph_only_radio, # pylint: disable=no-member
                               graph_vector_radio],
                       outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out])
@@ -472,9 +464,9 @@ def embedding_settings(embedding_type):
     args = parser.parse_args()
     app = FastAPI()
 
-    hugegraph_llm = create_hugegraph_llm_interface()
+    hugegraph_llm = init_rag_ui()
 
-    rag_web_http_api(app, graph_rag, apply_graph_configuration, apply_llm_configuration, apply_embedding_configuration)
+    rag_http_api(app, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config)
 
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
     # Note: set reload to False in production environment
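Taken together, the renames leave the demo entry point reading roughly as below. This is a sketch assembled from the diff context, not the full file: `args = parser.parse_args()`, `app = FastAPI()`, the two renamed calls, and the `mount_gradio_app` line appear in the diff, while the argparse flags and the final `uvicorn.run` call are assumptions consistent with the parser and the reload note visible above.

```python
# Sketch of the tail of rag_web_demo.py after the renames. init_rag_ui, rag_answer,
# apply_graph_config, apply_llm_config, and apply_embedding_config are defined
# earlier in the same file; the --host/--port flags and defaults are assumptions.
import argparse

import gradio as gr
import uvicorn
from fastapi import FastAPI

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0", help="bind address (assumed flag)")
    parser.add_argument("--port", type=int, default=8001, help="bind port (assumed flag)")
    args = parser.parse_args()
    app = FastAPI()

    hugegraph_llm = init_rag_ui()  # builds the gr.Blocks UI

    # Register /rag plus the /graph/config, /llm/config, /embedding/config routes
    rag_http_api(app, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config)

    # Serve the Gradio UI at "/" on the same FastAPI app
    app = gr.mount_gradio_app(app, hugegraph_llm, path="/")
    uvicorn.run(app, host=args.host, port=args.port)  # keep reload=False in production
```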
