Skip to content

Commit

Permalink
fix codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
simon824 committed Feb 2, 2024
1 parent 669cf2c commit 7277bb5
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 30 deletions.
5 changes: 3 additions & 2 deletions hugegraph-llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ builder = KgBuilder(LLMs().get_llm())
)
```

2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance,a user-defined schema or an extraction result. The method `print_result` can be chained to print the result.
2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance, a user-defined schema or an extraction result. The method `print_result` can be chained to print the result.

```python
# Import schema from a HugeGraph instance
import_schema(from_hugegraph="talent_graph").print_result()
import_schema(from_hugegraph="xxx").print_result()
# Import schema from an extraction result
import_schema(from_extraction="xxx").print_result()
# Import schema from user-defined schema
Expand Down
6 changes: 3 additions & 3 deletions hugegraph-llm/examples/build_kg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@

(
builder
.import_schema(from_hugegraph="talent_graph").print_result()
# .import_schema(from_extraction="fefe").print_result().run()
# .import_schema(from_input=schema).print_result()
.import_schema(from_hugegraph="xxx").print_result()
# .import_schema(from_extraction="xxx").print_result()
# .import_schema(from_user_defined=xxx).print_result()
.extract_triples(TEXT).print_result()
.disambiguate_word_sense()
.commit_to_hugegraph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ def __init__(self, data):
self.data = data

def run(self, schema=None) -> Any:
data = self.data or schema
if not isinstance(data, dict):
schema = self.data or schema
if not isinstance(schema, dict):
raise ValueError("Input data is not a dictionary.")
if "vertices" not in data or "edges" not in data:
if "vertices" not in schema or "edges" not in schema:
raise ValueError("Input data does not contain 'vertices' or 'edges'.")
if not isinstance(data["vertices"], list) or not isinstance(data["edges"], list):
if not isinstance(schema["vertices"], list) or not isinstance(schema["edges"], list):
raise ValueError("'vertices' or 'edges' in input data is not a list.")
for vertex in data["vertices"]:
for vertex in schema["vertices"]:
if not isinstance(vertex, dict):
raise ValueError("Vertex in input data is not a dictionary.")
if "vertex_label" not in vertex:
raise ValueError("Vertex in input data does not contain 'vertex_label'.")
if not isinstance(vertex["vertex_label"], str):
raise ValueError("'vertex_label' in vertex is not of correct type.")
for edge in data["edges"]:
for edge in schema["edges"]:
if not isinstance(edge, dict):
raise ValueError("Edge in input data is not a dictionary.")
if (
Expand All @@ -60,4 +60,4 @@ def run(self, schema=None) -> Any:
"'edge_label', 'source_vertex_label', 'target_vertex_label' "
"in edge is not of correct type."
)
return data
return schema
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ def schema_free_mode(self, data):
).secondary().ifNotExist().create()

for item in data:
s = item[0].strip()
p = item[1].strip()
o = item[2].strip()
s, p, o = (element.strip() for element in item)
s_id = self.client.graph().addVertex("vertex", {"name": s}, id=s).id
t_id = self.client.graph().addVertex("vertex", {"name": o}, id=o).id
self.client.graph().addEdge("edge", s_id, t_id, {"name": p})
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Dict, List, Any

from hugegraph_llm.llms.base import BaseLLM
from hugegraph_llm.operators.llm_op.info_extract import extract_by_regex
from hugegraph_llm.operators.llm_op.info_extract import extract_triples_by_regex


def generate_disambiguate_prompt(triples):
Expand Down Expand Up @@ -48,6 +48,6 @@ def run(self, data: Dict) -> Dict[str, List[Any]]:
prompt = generate_disambiguate_prompt(triples)
llm_output = self.llm.generate(prompt=prompt)
data = {"triples": []}
extract_by_regex(llm_output, data)
extract_triples_by_regex(llm_output, data)
print(f"LLM input:{prompt} \n output: {llm_output} \n data: {data}")
return data
12 changes: 6 additions & 6 deletions hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def fit_token_space_by_split_text(llm: BaseLLM, text: str, prompt_token: int) ->
return combined_chunks


def extract_by_regex(text, triples):
def extract_triples_by_regex(text, triples):
    """Collect ``(subject, predicate, object)`` triples found in *text*.

    The text is first flattened onto one line: literal backslash-n
    sequences, any remaining backslashes and real newline characters are
    each replaced by a single space. Every ``(a, b, c)`` group whose parts
    are separated by a comma followed by one space is then captured.

    Args:
        text: Raw text to scan (e.g. an LLM response).
        triples: Accumulator dict, mutated in place. Matches are appended to
            the list stored under the ``"triples"`` key; the key is created
            with an empty list when it is missing.

    Returns:
        None. Results are delivered through *triples*.
    """
    text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ")
    # Lazy groups keep each part as short as possible so that several
    # triples on the same line are matched separately.
    pattern = r"\((.*?), (.*?), (.*?)\)"
    # Tolerate an accumulator that has not been initialised yet instead of
    # failing with KeyError.
    triples.setdefault("triples", [])
    triples["triples"] += re.findall(pattern, text)


def extract_by_regex_with_schema(schema, text, graph):
def extract_triples_by_regex_with_schema(schema, text, graph):
text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ")
pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)"
matches = re.findall(pattern, text)
Expand Down Expand Up @@ -113,14 +113,14 @@ def run(self, schema=None) -> Dict[str, List[Any]]:

result = {"vertices": [], "edges": [], "schema": schema} if schema else {"triples": []}
for chunk in chunked_text:
proceeded_chunk = self.extract_by_llm(schema, chunk)
proceeded_chunk = self.extract_triples_by_llm(schema, chunk)
print(f"[LLM] input: {chunk} \n output:{proceeded_chunk}")
if schema:
extract_by_regex_with_schema(schema, proceeded_chunk, result)
extract_triples_by_regex_with_schema(schema, proceeded_chunk, result)
else:
extract_by_regex(proceeded_chunk, result)
extract_triples_by_regex(proceeded_chunk, result)
return result

def extract_by_llm(self, schema, chunk):
def extract_triples_by_llm(self, schema, chunk):
    """Build the triple-extraction prompt for *chunk* and return the LLM's reply.

    The prompt comes from ``generate_extract_triple_prompt(chunk, schema)``
    and is passed to ``self.llm.generate`` as the ``prompt`` keyword argument.
    """
    return self.llm.generate(prompt=generate_extract_triple_prompt(chunk, schema))
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def setUp(self):
],
},
}
self.llm = LLMs().get_llm()
# self.llm = None
self.llm = None
self.disambiguate_data = DisambiguateData(self.llm)

def test_run(self):
Expand Down
10 changes: 5 additions & 5 deletions hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

from hugegraph_llm.operators.llm_op.info_extract import (
InfoExtract,
extract_by_regex_with_schema,
extract_by_regex,
extract_triples_by_regex_with_schema,
extract_triples_by_regex,
)


Expand All @@ -40,7 +40,7 @@ def setUp(self):
}
],
}
# self.llm = LLMs().get_llm()

self.llm = None
self.info_extract = InfoExtract(self.llm, "text")

Expand Down Expand Up @@ -74,7 +74,7 @@ def setUp(self):

def test_extract_by_regex_with_schema(self):
graph = {"vertices": [], "edges": [], "schema": self.schema}
extract_by_regex_with_schema(self.schema, self.llm_output, graph)
extract_triples_by_regex_with_schema(self.schema, self.llm_output, graph)
self.assertEqual(
graph,
{
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_extract_by_regex_with_schema(self):

def test_extract_by_regex(self):
graph = {"triples": []}
extract_by_regex(self.llm_output, graph)
extract_triples_by_regex(self.llm_output, graph)
self.assertEqual(
graph,
{
Expand Down

0 comments on commit 7277bb5

Please sign in to comment.