Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change black config #17

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ jobs:
- uses: actions/checkout@v4
- uses: psf/black@stable
with:
src: "hugegraph-llm/src hugegraph-python-client/src"
options: "--check --verbose --line-length 100"
src: "hugegraph-llm/src hugegraph-python-client/src"
32 changes: 20 additions & 12 deletions hugegraph-llm/src/hugegraph_llm/llms/api_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,24 @@ def __init__(self):

@retry(tries=3, delay=1)
def generate(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
) -> str:
if messages is None:
assert prompt is not None, "Messages or prompt must be provided."
messages = [{"role": "user", "content": prompt}]
url = self.base_url

payload = json.dumps({
"messages": messages,
})
payload = json.dumps(
{
"messages": messages,
}
)
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
response = requests.request(
"POST", url, headers=headers, data=payload, timeout=30
)
if response.status_code != 200:
raise Exception(
f"Request failed with code {response.status_code}, message: {response.text}"
Expand All @@ -55,10 +59,10 @@ def generate(
return response_json["content"]

def generate_streaming(
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
self,
messages: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
on_token_callback: Callable = None,
) -> str:
return self.generate(messages, prompt)

Expand All @@ -75,4 +79,8 @@ def get_llm_type(self) -> str:
if __name__ == "__main__":
client = ApiBotClient()
print(client.generate(prompt="What is the capital of China?"))
print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}]))
print(
client.generate(
messages=[{"role": "user", "content": "What is the capital of China?"}]
)
)
18 changes: 14 additions & 4 deletions hugegraph-llm/src/hugegraph_llm/llms/ernie_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def get_access_token(self):
"client_id": self.api_key,
"client_secret": self.secret_key,
}
return str(requests.post(url, params=params, timeout=2).json().get("access_token"))
return str(
requests.post(url, params=params, timeout=2).json().get("access_token")
)

@retry(tries=3, delay=1)
def generate(
Expand All @@ -56,14 +58,18 @@ def generate(
# parameter check failed, temperature range is (0, 1.0]
payload = json.dumps({"messages": messages, "temperature": 0.1})
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
response = requests.request(
"POST", url, headers=headers, data=payload, timeout=30
)
if response.status_code != 200:
raise Exception(
f"Request failed with code {response.status_code}, message: {response.text}"
)
response_json = json.loads(response.text)
if "error_code" in response_json:
raise Exception(f"Error {response_json['error_code']}: {response_json['error_msg']}")
raise Exception(
f"Error {response_json['error_code']}: {response_json['error_msg']}"
)
return response_json["result"]

def generate_streaming(
Expand All @@ -87,4 +93,8 @@ def get_llm_type(self) -> str:
if __name__ == "__main__":
client = ErnieBotClient()
print(client.generate(prompt="What is the capital of China?"))
print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}]))
print(
client.generate(
messages=[{"role": "user", "content": "What is the capital of China?"}]
)
)
6 changes: 5 additions & 1 deletion hugegraph-llm/src/hugegraph_llm/llms/init_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ def get_llm(self):
if __name__ == "__main__":
client = LLMs().get_llm()
print(client.generate(prompt="What is the capital of China?"))
print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}]))
print(
client.generate(
messages=[{"role": "user", "content": "What is the capital of China?"}]
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ def run(self, schema=None) -> Any:
raise ValueError("Input data is not a dictionary.")
if "vertices" not in schema or "edges" not in schema:
raise ValueError("Input data does not contain 'vertices' or 'edges'.")
if not isinstance(schema["vertices"], list) or not isinstance(schema["edges"], list):
if not isinstance(schema["vertices"], list) or not isinstance(
schema["edges"], list
):
raise ValueError("'vertices' or 'edges' in input data is not a list.")
for vertex in schema["vertices"]:
if not isinstance(vertex, dict):
raise ValueError("Vertex in input data is not a dictionary.")
if "vertex_label" not in vertex:
raise ValueError("Vertex in input data does not contain 'vertex_label'.")
raise ValueError(
"Vertex in input data does not contain 'vertex_label'."
)
if not isinstance(vertex["vertex_label"], str):
raise ValueError("'vertex_label' in vertex is not of correct type.")
for edge in schema["edges"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def get_cache_dir() -> str:

# Windows (hopefully)
else:
local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser("~\\AppData\\Local")
local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
"~\\AppData\\Local"
)
path = Path(local, "hugegraph_llm")

if not os.path.exists(path):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,20 @@ def init_schema(self, schema):
properties = edge["properties"]
for prop in properties:
self.schema.propertyKey(prop).asText().ifNotExist().create()
self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel(
target_vertex_label
).properties(*properties).nullableKeys(*properties).ifNotExist().create()
self.schema.edgeLabel(edge_label).sourceLabel(
source_vertex_label
).targetLabel(target_vertex_label).properties(*properties).nullableKeys(
*properties
).ifNotExist().create()

def schema_free_mode(self, data):
self.schema.propertyKey("name").asText().ifNotExist().create()
self.schema.vertexLabel("vertex").useCustomizeStringId().properties(
"name"
).ifNotExist().create()
self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties(
"name"
).ifNotExist().create()
self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel(
"vertex"
).properties("name").ifNotExist().create()

self.schema.indexLabel("vertexByName").onV("vertex").by(
"name"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(
self._prop_to_match = prop_to_match
self._schema = ""


def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._client is None:
if isinstance(context.get("graph_client"), PyHugeClient):
Expand All @@ -95,7 +94,9 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
graph = context.get("graph") or "hugegraph"
user = context.get("user") or "admin"
pwd = context.get("pwd") or "admin"
self._client = PyHugeClient(ip=ip, port=port, graph=graph, user=user, pwd=pwd)
self._client = PyHugeClient(
ip=ip, port=port, graph=graph, user=user, pwd=pwd
)
assert self._client is not None, "No graph for query."

keywords = context.get("keywords")
Expand Down Expand Up @@ -138,8 +139,12 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
edge_labels=edge_labels_str,
)

result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
knowledge: Set[str] = self._format_knowledge_from_query_result(query_result=result)
result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)[
"data"
]
knowledge: Set[str] = self._format_knowledge_from_query_result(
query_result=result
)

context["synthesize_context_body"] = list(knowledge)
context["synthesize_context_head"] = (
Expand All @@ -152,7 +157,9 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
verbose = context.get("verbose") or False
if verbose:
print("\033[93mKNOWLEDGE FROM GRAPH:")
print("\n".join(rel for rel in context["synthesize_context_body"]) + "\033[0m")
print(
"\n".join(rel for rel in context["synthesize_context_body"]) + "\033[0m"
)

return context

Expand All @@ -171,7 +178,9 @@ def _format_knowledge_from_query_result(
for i, item in enumerate(raw_flat_rel):
if i % 2 == 0:
matched_str = (
item["id"] if use_id_to_match else item["props"][self._prop_to_match]
item["id"]
if use_id_to_match
else item["props"][self._prop_to_match]
)
if matched_str in node_cache:
flat_rel = flat_rel[:-prior_edge_str_len]
Expand Down Expand Up @@ -202,8 +211,12 @@ def _format_knowledge_from_query_result(
def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
schema = self._get_graph_schema()
node_props_str, edge_props_str = schema.split("\n")[:2]
node_props_str = node_props_str[len("Node properties: ") :].strip("[").strip("]")
edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]")
node_props_str = (
node_props_str[len("Node properties: ") :].strip("[").strip("]")
)
edge_props_str = (
edge_props_str[len("Edge properties: ") :].strip("[").strip("]")
)
node_labels = self._extract_label_names(node_props_str)
edge_labels = self._extract_label_names(edge_props_str)
return node_labels, edge_labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def __init__(self, llm: BaseLLM):
self.llm = llm
self.result = None

def import_schema(self, from_hugegraph=None, from_extraction=None, from_user_defined=None):
def import_schema(
self, from_hugegraph=None, from_extraction=None, from_user_defined=None
):
if from_hugegraph:
self.operators.append(SchemaManager(from_hugegraph))
elif from_user_defined:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def __init__(
context_tail: Optional[str] = None,
):
self._llm = llm
self._prompt_template = prompt_template or DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL
self._prompt_template = (
prompt_template or DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL
)
self._question = question
self._context_body = context_body
self._context_head = context_head
Expand All @@ -69,16 +71,22 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
elif isinstance(self._context_body, (list, set)):
context_body_str = "\n".join(line for line in self._context_body)
elif isinstance(self._context_body, dict):
context_body_str = "\n".join(f"{k}: {v}" for k, v in self._context_body.items())
context_body_str = "\n".join(
f"{k}: {v}" for k, v in self._context_body.items()
)
else:
context_body_str = str(self._context_body)

context_head_str = context.get("synthesize_context_head") or self._context_head or ""
context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""

context_str = (f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}").strip(
"\n"
context_head_str = (
context.get("synthesize_context_head") or self._context_head or ""
)
context_tail_str = (
context.get("synthesize_context_tail") or self._context_tail or ""
)

context_str = (
f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}"
).strip("\n")

prompt = self._prompt_template.format(
context_str=context_str,
Expand Down
27 changes: 21 additions & 6 deletions hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,18 @@ def generate_extract_triple_prompt(text, schema=None) -> str:
The extracted text is: {text}"""


def fit_token_space_by_split_text(llm: BaseLLM, text: str, prompt_token: int) -> List[str]:
def fit_token_space_by_split_text(
llm: BaseLLM, text: str, prompt_token: int
) -> List[str]:
max_length = 500
allowed_tokens = llm.max_allowed_token_length() - prompt_token
chunked_data = [text[i : i + max_length] for i in range(0, len(text), max_length)]
combined_chunks = []
current_chunk = ""
for chunk in chunked_data:
if (
int(llm.num_tokens_from_string(current_chunk)) + int(llm.num_tokens_from_string(chunk))
int(llm.num_tokens_from_string(current_chunk))
+ int(llm.num_tokens_from_string(chunk))
< allowed_tokens
):
current_chunk += chunk
Expand Down Expand Up @@ -84,13 +87,19 @@ def extract_triples_by_regex_with_schema(schema, text, graph):
for vertex in schema["vertices"]:
if vertex["vertex_label"] == label and p in vertex["properties"]:
if (s, label) not in vertices_dict:
vertices_dict[(s, label)] = {"name": s, "label": label, "properties": {p: o}}
vertices_dict[(s, label)] = {
"name": s,
"label": label,
"properties": {p: o},
}
else:
vertices_dict[(s, label)]["properties"].update({p: o})
break
for edge in schema["edges"]:
if edge["edge_label"] == label:
graph["edges"].append({"start": s, "end": o, "type": label, "properties": {}})
graph["edges"].append(
{"start": s, "end": o, "type": label, "properties": {}}
)
break
graph["vertices"] = list(vertices_dict.values())

Expand All @@ -105,13 +114,19 @@ def __init__(
self.text = text

def run(self, schema=None) -> Dict[str, List[Any]]:
prompt_token = self.llm.num_tokens_from_string(generate_extract_triple_prompt("", schema))
prompt_token = self.llm.num_tokens_from_string(
generate_extract_triple_prompt("", schema)
)

chunked_text = fit_token_space_by_split_text(
llm=self.llm, text=self.text, prompt_token=int(prompt_token)
)

result = {"vertices": [], "edges": [], "schema": schema} if schema else {"triples": []}
result = (
{"vertices": [], "edges": [], "schema": schema}
if schema
else {"triples": []}
)
for chunk in chunked_text:
proceeded_chunk = self.extract_triples_by_llm(schema, chunk)
print(f"[LLM] input: {chunk} \n output:{proceeded_chunk}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def __init__(
self._query = text
self._language = language.lower()
self._max_keywords = max_keywords
self._extract_template = extract_template or DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL
self._extract_template = (
extract_template or DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL
)
self._expand_template = expand_template or DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -132,7 +134,11 @@ def _extract_keywords_from_response(
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update(
{w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)}
{
w
for w in sub_tokens
if w not in NLTKHelper().stopwords(lang=self._language)
}
)

return results
Loading
Loading