From 7277bb559c9250e9f92f0ed7f1405e787cb4da0c Mon Sep 17 00:00:00 2001 From: simon824 Date: Fri, 2 Feb 2024 16:41:27 +0800 Subject: [PATCH] fix codestyle --- hugegraph-llm/README.md | 5 +++-- hugegraph-llm/examples/build_kg_test.py | 6 +++--- .../operators/common_op/check_schema.py | 14 +++++++------- .../operators/hugegraph_op/commit_to_hugegraph.py | 4 +--- .../operators/llm_op/disambiguate_data.py | 4 ++-- .../hugegraph_llm/operators/llm_op/info_extract.py | 12 ++++++------ .../operators/llm_op/test_disambiguate_data.py | 3 +-- .../tests/operators/llm_op/test_info_extract.py | 10 +++++----- 8 files changed, 28 insertions(+), 30 deletions(-) diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index cb8e3f43..493b2e7e 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -44,10 +44,11 @@ builder = KgBuilder(LLMs().get_llm()) ) ``` -2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance,a user-defined schema or an extraction result. The method `print_result` can be chained to print the result. +2. **Import Schema**: The `import_schema` method is used to import a schema from a source. The source can be a HugeGraph instance, a user-defined schema or an extraction result. The method `print_result` can be chained to print the result. + ```python # Import schema from a HugeGraph instance -import_schema(from_hugegraph="talent_graph").print_result() +import_schema(from_hugegraph="xxx").print_result() # Import schema from an extraction result import_schema(from_extraction="xxx").print_result() # Import schema from user-defined schema diff --git a/hugegraph-llm/examples/build_kg_test.py b/hugegraph-llm/examples/build_kg_test.py index 1274ed10..3380c429 100644 --- a/hugegraph-llm/examples/build_kg_test.py +++ b/hugegraph-llm/examples/build_kg_test.py @@ -50,9 +50,9 @@ ( builder - .import_schema(from_hugegraph="talent_graph").print_result() - # .import_schema(from_extraction="fefe").print_result().run() - # .import_schema(from_input=schema).print_result() + .import_schema(from_hugegraph="xxx").print_result() + # .import_schema(from_extraction="xxx").print_result() + # .import_schema(from_user_defined=xxx).print_result() .extract_triples(TEXT).print_result() .disambiguate_word_sense() .commit_to_hugegraph() diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index bd491bdf..0228a976 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -25,21 +25,21 @@ def __init__(self, data): self.data = data def run(self, schema=None) -> Any: - data = self.data or schema - if not isinstance(data, dict): + schema = self.data or schema + if not isinstance(schema, dict): raise ValueError("Input data is not a dictionary.") - if "vertices" not in data or "edges" not in data: + if "vertices" not in schema or "edges" not in schema: raise ValueError("Input data does not contain 'vertices' or 'edges'.") - if not isinstance(data["vertices"], list) or not isinstance(data["edges"], list): + if not isinstance(schema["vertices"], list) or not isinstance(schema["edges"], list): raise ValueError("'vertices' or 'edges' in input data is not a list.") - for vertex in data["vertices"]: + for vertex in schema["vertices"]: if not isinstance(vertex, dict): raise ValueError("Vertex in input data is not a dictionary.") if "vertex_label" not in vertex: raise ValueError("Vertex in input data does not contain 'vertex_label'.") if not isinstance(vertex["vertex_label"], str): raise ValueError("'vertex_label' in vertex is not of correct type.") - for edge in data["edges"]: + for edge in schema["edges"]: if not isinstance(edge, dict): raise ValueError("Edge in input data is not a dictionary.") if ( @@ -60,4 +60,4 @@ def run(self, schema=None) -> Any: "'edge_label', 'source_vertex_label', 'target_vertex_label' " "in edge is not of correct type." ) - return data + return schema diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 62443c2f..558a8ba8 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -104,9 +104,7 @@ def schema_free_mode(self, data): ).secondary().ifNotExist().create() for item in data: - s = item[0].strip() - p = item[1].strip() - o = item[2].strip() + s, p, o = (element.strip() for element in item) s_id = self.client.graph().addVertex("vertex", {"name": s}, id=s).id t_id = self.client.graph().addVertex("vertex", {"name": o}, id=o).id self.client.graph().addEdge("edge", s_id, t_id, {"name": p}) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index a0879342..d279e52d 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -19,7 +19,7 @@ from typing import Dict, List, Any from hugegraph_llm.llms.base import BaseLLM -from hugegraph_llm.operators.llm_op.info_extract import extract_by_regex +from hugegraph_llm.operators.llm_op.info_extract import extract_triples_by_regex def generate_disambiguate_prompt(triples): @@ -48,6 +48,6 @@ def run(self, data: Dict) -> Dict[str, List[Any]]: prompt = generate_disambiguate_prompt(triples) llm_output = self.llm.generate(prompt=prompt) data = {"triples": []} - extract_by_regex(llm_output, data) + extract_triples_by_regex(llm_output, data) print(f"LLM input:{prompt} \n output: {llm_output} \n data: {data}") return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 18b4083e..ab35da2b 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -65,13 +65,13 @@ def fit_token_space_by_split_text(llm: BaseLLM, text: str, prompt_token: int) -> return combined_chunks -def extract_by_regex(text, triples): +def extract_triples_by_regex(text, triples): text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") pattern = r"\((.*?), (.*?), (.*?)\)" triples["triples"] += re.findall(pattern, text) -def extract_by_regex_with_schema(schema, text, graph): +def extract_triples_by_regex_with_schema(schema, text, graph): text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" matches = re.findall(pattern, text) @@ -113,14 +113,14 @@ def run(self, schema=None) -> Dict[str, List[Any]]: result = {"vertices": [], "edges": [], "schema": schema} if schema else {"triples": []} for chunk in chunked_text: - proceeded_chunk = self.extract_by_llm(schema, chunk) + proceeded_chunk = self.extract_triples_by_llm(schema, chunk) print(f"[LLM] input: {chunk} \n output:{proceeded_chunk}") if schema: - extract_by_regex_with_schema(schema, proceeded_chunk, result) + extract_triples_by_regex_with_schema(schema, proceeded_chunk, result) else: - extract_by_regex(proceeded_chunk, result) + extract_triples_by_regex(proceeded_chunk, result) return result - def extract_by_llm(self, schema, chunk): + def extract_triples_by_llm(self, schema, chunk): prompt = generate_extract_triple_prompt(chunk, schema) return self.llm.generate(prompt=prompt) diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py index 350eeba2..ba800ce1 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_disambiguate_data.py @@ -74,8 +74,7 @@ def setUp(self): ], }, } - self.llm = LLMs().get_llm() - # self.llm = None + self.llm = None self.disambiguate_data = DisambiguateData(self.llm) def test_run(self): diff --git a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py index 596b0896..cd84521e 100644 --- a/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py +++ b/hugegraph-llm/src/tests/operators/llm_op/test_info_extract.py @@ -19,8 +19,8 @@ from hugegraph_llm.operators.llm_op.info_extract import ( InfoExtract, - extract_by_regex_with_schema, - extract_by_regex, + extract_triples_by_regex_with_schema, + extract_triples_by_regex, ) @@ -40,7 +40,7 @@ def setUp(self): } ], } - # self.llm = LLMs().get_llm() + self.llm = None self.info_extract = InfoExtract(self.llm, "text") @@ -74,7 +74,7 @@ def setUp(self): def test_extract_by_regex_with_schema(self): graph = {"vertices": [], "edges": [], "schema": self.schema} - extract_by_regex_with_schema(self.schema, self.llm_output, graph) + extract_triples_by_regex_with_schema(self.schema, self.llm_output, graph) self.assertEqual( graph, { @@ -120,7 +120,7 @@ def test_extract_by_regex_with_schema(self): def test_extract_by_regex(self): graph = {"triples": []} - extract_by_regex(self.llm_output, graph) + extract_triples_by_regex(self.llm_output, graph) self.assertEqual( graph, {