Skip to content

Commit

Permalink
Merge pull request #17 from Wenbintum/BonDNet
Browse files Browse the repository at this point in the history
HiPRGen-BonDNet lmdbs dispatcher level
  • Loading branch information
samblau authored Dec 18, 2023
2 parents 2e5e0eb + 8adf0b4 commit 182e0a6
Show file tree
Hide file tree
Showing 8 changed files with 1,487 additions and 145 deletions.
475 changes: 425 additions & 50 deletions HiPRGen/lmdb_dataset.py

Large diffs are not rendered by default.

76 changes: 56 additions & 20 deletions HiPRGen/reaction_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ def log_message(*args, **kwargs):
'[' + strftime('%H:%M:%S', localtime()) + ']',
*args, **kwargs)

#restructure input of dispatcher
def dispatcher(
mol_entries,
dgl_molecules_dict,
grapher_features,
dispatcher_payload
mol_entries, #1
#dgl_molecules_dict,
#grapher_features,
dispatcher_payload, #2
#reaction_lmdb_path
):

comm = MPI.COMM_WORLD
Expand Down Expand Up @@ -137,14 +139,17 @@ def dispatcher(

#### HY
## initialize preprocess data

rxn_networks_g = rxn_networks_graph(
mol_entries,
dgl_molecules_dict,
grapher_features,
dispatcher_payload.bondnet_test
)
####

# #wx: writing lmdbs in dispatcher ?
# #wx: each worker needs to initialize rxn_networks_graph at worker level.
# rxn_networks_g = rxn_networks_graph(
# mol_entries,
# dgl_molecules_dict,
# grapher_features,
# dispatcher_payload.bondnet_test,
# reaction_lmdb_path #wx. different
# )
# ####

log_message("initializing report generator")

Expand All @@ -159,6 +164,7 @@ def dispatcher(
worker_states = {}

worker_ranks = [i for i in range(comm.Get_size()) if i != DISPATCHER_RANK]
print("worker_states",worker_states)

for i in worker_ranks:
worker_states[i] = WorkerState.INITIALIZING
Expand All @@ -170,6 +176,7 @@ def dispatcher(

log_message("all workers running")

#global index, which is different from the local index in each worker
reaction_index = 0

log_message("handling requests")
Expand Down Expand Up @@ -206,6 +213,7 @@ def dispatcher(
tag = status.Get_tag()
rank = status.Get_source()

#this is the last step when worker is out of work.
if tag == SEND_ME_A_WORK_BATCH:
if len(work_batch_list) == 0:
comm.send(None, dest=rank, tag=HERE_IS_A_WORK_BATCH)
Expand All @@ -221,9 +229,11 @@ def dispatcher(
": group ids:",
group_id_0, group_id_1
)


elif tag == NEW_REACTION_DB:
#this is where the worker is doing things: it found a good reaction and sent it to the dispatcher.
#if this is correct, then create_rxn_networks_graph operates on the worker instead of the dispatcher.
#This is the reason why samples are added one by one. QA: where is the batch of reactions?
#ten reactions, first filter out, second send , next eight
elif tag == NEW_REACTION_DB:
reaction = data
rn_cur.execute(
insert_reaction,
Expand All @@ -240,8 +250,10 @@ def dispatcher(
reaction['is_redox']
))

# Create reaction graph + add to LMDB
rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index)
# # # Create reaction graph + add to LMDB
# rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) #wx in worker level

#dispatcher tracks global index, worker tracks local index in that batch.

reaction_index += 1
if reaction_index % dispatcher_payload.commit_frequency == 0:
Expand Down Expand Up @@ -275,17 +287,37 @@ def dispatcher(


def worker(
mol_entries,
worker_payload
mol_entries, #input of worker
worker_payload,
dgl_molecules_dict,
grapher_features,
reaction_lmdb_path

):

# import pdb
# pdb.set_trace()

local_reaction_idx = 0 #wx add local_idx

comm = MPI.COMM_WORLD
con = sqlite3.connect(worker_payload.bucket_db_file)
cur = con.cursor()


comm.send(None, dest=DISPATCHER_RANK, tag=INITIALIZATION_FINISHED)

rank = comm.Get_rank() #get id of that worker

lmdb_name_i = reaction_lmdb_path.split(".lmdb")[0] + "_" + str(rank) + ".lmdb"

rxn_networks_g = rxn_networks_graph(
mol_entries,
dgl_molecules_dict,
grapher_features,
#dispatcher_payload.bondnet_test, #can be removed
lmdb_name_i #wx. different
)

while True:
comm.send(None, dest=DISPATCHER_RANK, tag=SEND_ME_A_WORK_BATCH)
work_batch = comm.recv(source=DISPATCHER_RANK, tag=HERE_IS_A_WORK_BATCH)
Expand Down Expand Up @@ -352,6 +384,10 @@ def worker(
dest=DISPATCHER_RANK,
tag=NEW_REACTION_DB)

#comm.send, send reaction to dispatchers.

rxn_networks_g.create_rxn_networks_graph(reaction, local_reaction_idx)
local_reaction_idx+=1


if run_decision_tree(reaction,
Expand Down
Loading

0 comments on commit 182e0a6

Please sign in to comment.