Skip to content

Commit

Permalink
getting working
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 committed Nov 26, 2024
1 parent 4b664a2 commit 55ff148
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions torch_geometric/nn/nlp/txt2kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ def __init__(
self.system_prompt = system_prompt
else:
assert NVIDIA_API_KEY != '', "Please pass NVIDIA_API_KEY or set local_small_lm flag to True"
global GLOBAL_NIM_KEY
print("inside init, GLOBAL_NIM_KEY=", GLOBAL_NIM_KEY)
GLOBAL_NIM_KEY = NVIDIA_API_KEY
print("inside init after update, GLOBAL_NIM_KEY=", GLOBAL_NIM_KEY)
self.NVIDIA_API_KEY = NVIDIA_API_KEY

self.chunk_size = 512
# useful for approximating recall of subgraph retrieval algos
Expand Down Expand Up @@ -101,21 +98,19 @@ def add_doc_2_KG(
mp.spawn(
multiproc_helper,
args=(in_chunks_per_proc, outs_per_proc, parse_n_check_triples,
chunk_to_triples_str_cloud))
chunk_to_triples_str_cloud, self.NVIDIA_API_KEY))
self.relevant_triples[key] = []
for proc_i_out in outs_per_proc.values():
self.relevant_triples[key] += proc_i_out
self.doc_id_counter += 1


def chunk_to_triples_str_cloud(txt: str) -> str:
def chunk_to_triples_str_cloud(txt: str, GLOBAL_NIM_KEY='') -> str:
global CLIENT_INITD
if not CLIENT_INITD:
# We use NIMs since most PyG users may not be able to run a 70B+ model
from openai import OpenAI
global GLOBAL_NIM_KEY
global CLIENT
print("GLOBAL_NIM_KEY inside func=", GLOBAL_NIM_KEY)
CLIENT = OpenAI(base_url="https://integrate.api.nvidia.com/v1",
api_key=GLOBAL_NIM_KEY)
global NIM_MODEL
Expand Down Expand Up @@ -163,16 +158,16 @@ def parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]:
return processed


def llm_then_python_parse(chunks, py_fn, llm_fn):
def llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs):
relevant_triples = []
for chunk in chunks:
relevant_triples += py_fn(llm_fn(chunk))
relevant_triples += py_fn(llm_fn(chunk, **kwargs))
return relevant_triples


def multiproc_helper(rank, in_chunks_per_proc, outs_per_proc, py_fn, llm_fn):
def multiproc_helper(rank, in_chunks_per_proc, outs_per_proc, py_fn, llm_fn, NIM_KEY):
outs_per_proc[rank] = llm_then_python_parse(in_chunks_per_proc[rank],
py_fn, llm_fn)
py_fn, llm_fn, GLOBAL_NIM_KEY=NIM_KEY)


def get_num_procs():
Expand Down

0 comments on commit 55ff148

Please sign in to comment.