From 434838b7a8040eb76a0ba1264385c30106dcbaa2 Mon Sep 17 00:00:00 2001
From: Wenbin Xu
Date: Thu, 7 Dec 2023 11:44:58 -0800
Subject: [PATCH 1/3] HiPRGen-BonDNet lmdb dispatcher version

---
 HiPRGen/.reaction_filter.py.swp     | Bin 0 -> 16384 bytes
 HiPRGen/.reaction_filter_new.py.swp | Bin 0 -> 16384 bytes
 HiPRGen/lmdb_dataset.py             | 308 ++++++++++++++++++++--
 HiPRGen/reaction_filter.py          |   9 +-
 HiPRGen/reaction_filter_new.py      | 396 ++++++++++++++++++++++++++++
 HiPRGen/rxn_networks_graph.py       | 166 +++++++-----
 HiPRGen/species_filter.py           |  60 ++++-
 run_network_generation.py           |   7 +-
 test.py                             |  18 +-
 9 files changed, 874 insertions(+), 90 deletions(-)
 create mode 100644 HiPRGen/.reaction_filter.py.swp
 create mode 100644 HiPRGen/.reaction_filter_new.py.swp
 create mode 100644 HiPRGen/reaction_filter_new.py

diff --git a/HiPRGen/.reaction_filter.py.swp b/HiPRGen/.reaction_filter.py.swp
new file mode 100644
index 0000000000000000000000000000000000000000..cbcbd2540f6c0974107d1a92d71c7a6bac7ccdb8
GIT binary patch
literal 16384
[base85-encoded payload omitted: 16 KiB Vim swap file accidentally committed here; removed again in PATCH 2/3]

literal 0
HcmV?d00001

diff --git a/HiPRGen/.reaction_filter_new.py.swp b/HiPRGen/.reaction_filter_new.py.swp
new file mode 100644
index 0000000000000000000000000000000000000000..05f29ca899322cef97a635f3c24942d92ccfdec9
GIT binary patch
literal 16384
[base85-encoded payload omitted: 16 KiB Vim swap file accidentally committed here; removed again in PATCH 2/3]

literal 0
HcmV?d00001
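Series note (a sketch, not part of the patches): this series builds up an LMDB-backed
reader/writer for reaction networks. Assuming only the LmdbReactionDataset constructor
and the metadata keys added in HiPRGen/lmdb_dataset.py below, a finished reaction lmdb
would be read back roughly like this:

    from HiPRGen.lmdb_dataset import LmdbReactionDataset

    ds = LmdbReactionDataset({"src": "reaction.lmdb"})  # single-file mode
    print(len(ds))           # comes from the pickled "length" metadata entry
    print(ds.mean, ds.std)   # dG statistics maintained while writing
    sample = ds[0]           # pickled record: reaction graph, features, label, ...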
diff --git a/HiPRGen/lmdb_dataset.py b/HiPRGen/lmdb_dataset.py
index f56192e..8a7111d 100644
--- a/HiPRGen/lmdb_dataset.py
+++ b/HiPRGen/lmdb_dataset.py
@@ -17,22 +17,24 @@
 import glob
 
 
-class LmdbDataset(Dataset):
+class LmdbBaseDataset(Dataset):
+
     """
-    Dataset class to 
+    Dataset class to
     1. write reaction network objects to lmdb
     2. load lmdb files
     """
+
     def __init__(self, config, transform=None):
-        super(LmdbDataset, self).__init__()
+        super(LmdbBaseDataset, self).__init__()
 
         self.config = config
         self.path = Path(self.config["src"])
 
-        #Get metadata in case
-        #self.metadata_path = self.path.parent / "metadata.npz"
+        # Get metadata in case
+        # self.metadata_path = self.path.parent / "metadata.npz"
         self.env = self.connect_db(self.path)
-        
+
         # If "length" encoded as ascii is present, use that
         # If there are additional properties, there must be length.
         length_entry = self.env.begin().get("length".encode("ascii"))
         if length_entry is not None:
             num_entries = pickle.loads(length_entry)
@@ -45,8 +47,8 @@ def __init__(self, config, transform=None):
 
         self._keys = list(range(num_entries))
         self.num_samples = num_entries
-        
-        #Get portion of total dataset
+
+        # Get portion of total dataset
         self.sharded = False
         if "shard" in self.config and "total_shards" in self.config:
             self.sharded = True
@@ -58,8 +60,8 @@ def __init__(self, config, transform=None):
             # limit each process to see a subset of data based off defined shard
             self.available_indices = self.shards[self.config.get("shard", 0)]
             self.num_samples = len(self.available_indices)
-        
-        #TODO
+
+        # TODO
         self.transform = transform
 
     def __len__(self):
         return self.num_samples
 
@@ -71,13 +73,11 @@ def __getitem__(self, idx):
             idx = self.available_indices[idx]
 
         #!CHECK, _keys should be fewer than the total number of keys, since the extra entries hold metadata properties.
- datapoint_pickled = self.env.begin().get( - f"{self._keys[idx]}".encode("ascii") - ) - + datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii")) + data_object = pickle.loads(datapoint_pickled) - #TODO + # TODO if self.transform is not None: data_object = self.transform(data_object) @@ -87,7 +87,7 @@ def connect_db(self, lmdb_path=None): env = lmdb.open( str(lmdb_path), subdir=False, - readonly=True, + readonly=False, lock=False, readahead=True, meminit=False, @@ -105,6 +105,36 @@ def close_db(self): def get_metadata(self, num_samples=100): pass + +class LmdbMoleculeDataset(LmdbBaseDataset): + def __init__(self, config, transform=None): + super(LmdbMoleculeDataset, self).__init__(config=config, transform=transform) + + @property + def charges(self): + charges = self.env.begin().get("charges".encode("ascii")) + return pickle.loads(charges) + + @property + def ring_sizes(self): + ring_sizes = self.env.begin().get("ring_sizes".encode("ascii")) + return pickle.loads(ring_sizes) + + @property + def elements(self): + elements = self.env.begin().get("elements".encode("ascii")) + return pickle.loads(elements) + + @property + def feature_info(self): + feature_info = self.env.begin().get("feature_info".encode("ascii")) + return pickle.loads(feature_info) + + +class LmdbReactionDataset(LmdbBaseDataset): + def __init__(self, config, transform=None): + super(LmdbReactionDataset, self).__init__(config=config, transform=transform) + @property def dtype(self): dtype = self.env.begin().get("dtype".encode("ascii")) @@ -131,6 +161,121 @@ def std(self): return pickle.loads(std) + +# class LmdbDataset(Dataset): +# """ +# Dataset class to +# 1. write Reaction networks objecs to lmdb +# 2. load lmdb files +# """ +# def __init__(self, config, transform=None): +# super(LmdbDataset, self).__init__() + +# self.config = config +# self.path = Path(self.config["src"]) + +# #Get metadata in case +# #self.metadata_path = self.path.parent / "metadata.npz" +# self.env = self.connect_db(self.path) + +# # If "length" encoded as ascii is present, use that +# # If there are additional properties, there must be length. +# length_entry = self.env.begin().get("length".encode("ascii")) +# if length_entry is not None: +# num_entries = pickle.loads(length_entry) +# else: +# # Get the number of stores data from the number of entries +# # in the LMDB +# num_entries = self.env.stat()["entries"] + +# self._keys = list(range(num_entries)) +# self.num_samples = num_entries + +# #Get portion of total dataset +# self.sharded = False +# if "shard" in self.config and "total_shards" in self.config: +# self.sharded = True +# self.indices = range(self.num_samples) +# # split all available indices into 'total_shards' bins +# self.shards = np.array_split( +# self.indices, self.config.get("total_shards", 1) +# ) +# # limit each process to see a subset of data based off defined shard +# self.available_indices = self.shards[self.config.get("shard", 0)] +# self.num_samples = len(self.available_indices) + +# #TODO +# self.transform = transform + +# def __len__(self): +# return self.num_samples + +# def __getitem__(self, idx): +# # if sharding, remap idx to appropriate idx of the sharded set +# if self.sharded: +# idx = self.available_indices[idx] + +# #!CHECK, _keys should be less then total numbers of keys as there are more properties. 
+# datapoint_pickled = self.env.begin().get( +# f"{self._keys[idx]}".encode("ascii") +# ) + +# data_object = pickle.loads(datapoint_pickled) + +# #TODO +# if self.transform is not None: +# data_object = self.transform(data_object) + +# return data_object + +# def connect_db(self, lmdb_path=None): +# env = lmdb.open( +# str(lmdb_path), +# subdir=False, +# readonly=True, +# lock=False, +# readahead=True, +# meminit=False, +# max_readers=1, +# ) +# return env + +# def close_db(self): +# if not self.path.is_file(): +# for env in self.envs: +# env.close() +# else: +# self.env.close() + +# def get_metadata(self, num_samples=100): +# pass + +# @property +# def dtype(self): +# dtype = self.env.begin().get("dtype".encode("ascii")) +# return pickle.loads(dtype) + +# @property +# def feature_size(self): +# feature_size = self.env.begin().get("feature_size".encode("ascii")) +# return pickle.loads(feature_size) + +# @property +# def feature_name(self): +# feature_name = self.env.begin().get("feature_name".encode("ascii")) +# return pickle.loads(feature_name) + +# @property +# def mean(self): +# mean = self.env.begin().get("mean".encode("ascii")) +# return pickle.loads(mean) + +# @property +# def std(self): +# std = self.env.begin().get("std".encode("ascii")) +# return pickle.loads(std) + + def divide_to_list(a, b): quotient = a // b remainder = a % b @@ -280,10 +425,13 @@ def merge_lmdbs(db_paths, out_path, output_file): env_out.sync() env_out.close() +def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): + """ + put new_samples into lmdbs, + update length and global features. + """ -def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): - # #pid is idx of workers. # db_path, samples, pid, meta_keys = mp_args db = lmdb.open( @@ -312,7 +460,6 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): txn.commit() #write properties - total_length = current_length + len(new_samples) txn=db.begin(write=True) @@ -325,5 +472,128 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): txn.put(key.encode("ascii"), pickle.dumps(value, protocol=-1)) txn.commit() + db.sync() + db.close() + + +def write2moleculelmdb(mp_args + ): + """ + write molecule lmdb in parallel. + in species filter, there should be only one thread. no need parallelizations. + """ + db_path, samples, global_keys, pid = mp_args + #Samples: [mol_indices, dgl_graph, pmg] + #Global_keys: [charge, ring_sizes, elements.] + #Pid: i_th process + + db = lmdb.open( + db_path, + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + pbar = tqdm( + total=len(samples), + position=pid, + desc=f"Worker {pid}: Writing LMDBs", + ) + #write samples + for sample in samples: + sample_index = sample["molecule_index"] + txn = db.begin(write=True) + txn.put( + #let index of molecule identical to index of sample + f"{sample_index}".encode("ascii"), + pickle.dumps(sample, protocol=-1), + ) + pbar.update(1) + txn.commit() + + #write properties. 
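+    # Layout note: every sample is pickled under its integer molecule_index
+    # (ascii-encoded), and the metadata written below ("length" plus the
+    # global keys) lives in the same key space; that is why readers trust the
+    # "length" entry instead of env.stat()["entries"].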
+ txn = db.begin(write=True) + txn.put("length".encode("ascii"), pickle.dumps(len(samples), protocol=-1)) + txn.commit() + + for key, value in global_keys.items(): + txn = db.begin(write=True) + txn.put(key.encode("ascii"), pickle.dumps(value, protocol=-1)) + txn.commit() + + db.sync() + db.close() + + +def dump_molecule_lmdb( + indices, + graphs, + pmgs, + charges, + ring_sizes, + elements, + lmdb_path +): + #1 load lmdb + lmdb_path = lmdb_path + lmdb_file = Path(lmdb_path) + # if lmdb_file.is_file(): + # # file exists + # print("Molecular lmdb already exists") + # else: + key_tempalte = ["molecule_index", "molecule_graph", "molecule_wrapper"] + + dataset = [{k: v for k, v in zip(key_tempalte, values)} for values in zip(indices, graphs, pmgs)] + + global_keys = { + "charges" : charges, + "ring_sizes": ring_sizes, + "elements": elements, + "feature_info" : {}, #TODO + } + + # import pdb + # pdb.set_trace() + + db = lmdb.open( + lmdb_path, + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + #Samples: [mol_indices, dgl_graph, pmg] + #Global_keys: [charge, ring_sizes, elements.] + pbar = tqdm( + total=len(dataset), + desc="Writing Molecular LMDBs", + ) + + #write samples + for sample in dataset: + sample_index = sample["molecule_index"] + txn = db.begin(write=True) + txn.put( + #let index of molecule identical to index of sample + f"{sample_index}".encode("ascii"), + pickle.dumps(sample, protocol=-1), + ) + pbar.update(1) + txn.commit() + + #write properties + #write length + txn = db.begin(write=True) + txn.put("length".encode("ascii"), pickle.dumps(len(dataset), protocol=-1)) + txn.commit() + + #write global keys. + for key, value in global_keys.items(): + txn = db.begin(write=True) + txn.put(key.encode("ascii"), pickle.dumps(value, protocol=-1)) + txn.commit() + db.sync() db.close() \ No newline at end of file diff --git a/HiPRGen/reaction_filter.py b/HiPRGen/reaction_filter.py index 02f2c7f..53d6d83 100644 --- a/HiPRGen/reaction_filter.py +++ b/HiPRGen/reaction_filter.py @@ -108,7 +108,8 @@ def dispatcher( mol_entries, dgl_molecules_dict, grapher_features, - dispatcher_payload + dispatcher_payload, + reaction_lmdb_path ): comm = MPI.COMM_WORLD @@ -138,11 +139,13 @@ def dispatcher( #### HY ## initialize preprocess data +#wx: writting lmdbs in dispatcher ? 
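+# Note: an lmdb environment admits only one concurrent write transaction, so
+# building the reaction lmdb here serializes all writes on the dispatcher rank;
+# PATCH 3/3 moves this into the workers, one lmdb file per rank.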
rxn_networks_g = rxn_networks_graph( mol_entries, dgl_molecules_dict, grapher_features, - dispatcher_payload.bondnet_test + dispatcher_payload.bondnet_test, + reaction_lmdb_path ) #### @@ -365,4 +368,4 @@ def worker( ), dest=DISPATCHER_RANK, - tag=NEW_REACTION_LOGGING) + tag=NEW_REACTION_LOGGING) \ No newline at end of file diff --git a/HiPRGen/reaction_filter_new.py b/HiPRGen/reaction_filter_new.py new file mode 100644 index 0000000..c5bd263 --- /dev/null +++ b/HiPRGen/reaction_filter_new.py @@ -0,0 +1,396 @@ +from mpi4py import MPI +from HiPRGen.rxn_networks_graph import rxn_networks_graph +from itertools import permutations, product +from HiPRGen.report_generator import ReportGenerator +import sqlite3 +from time import localtime, strftime, time +from enum import Enum +from math import floor +from HiPRGen.reaction_filter_payloads import ( + DispatcherPayload, + WorkerPayload +) + +from HiPRGen.reaction_questions import ( + run_decision_tree +) + + +""" +Phases 3 & 4 run in parallel using MPI + +Phase 3: reaction gen and filtering +input: a bucket labeled by atom count +output: a list of reactions from that bucket +description: Loop through all possible reactions in the bucket and apply the decision tree. This will run in parallel over each bucket. + +Phase 4: collating and indexing +input: all the outputs of phase 3 as they are generated +output: reaction network database +description: the worker processes from phase 3 are sending their reactions to this phase and it is writing them to DB as it gets them. We can ensure that duplicates don't get generated in phase 3 which means we don't need extra index tables on the db. + +the code in this file is designed to run on a compute cluster using MPI. +""" + + +create_metadata_table = """ + CREATE TABLE metadata ( + number_of_species INTEGER NOT NULL, + number_of_reactions INTEGER NOT NULL + ); +""" + +insert_metadata = """ + INSERT INTO metadata VALUES (?, ?) +""" + +# it is important that reaction_id is the primary key +# otherwise the network loader will be extremely slow. +create_reactions_table = """ + CREATE TABLE reactions ( + reaction_id INTEGER NOT NULL PRIMARY KEY, + number_of_reactants INTEGER NOT NULL, + number_of_products INTEGER NOT NULL, + reactant_1 INTEGER NOT NULL, + reactant_2 INTEGER NOT NULL, + product_1 INTEGER NOT NULL, + product_2 INTEGER NOT NULL, + rate REAL NOT NULL, + dG REAL NOT NULL, + dG_barrier REAL NOT NULL, + is_redox INTEGER NOT NULL + ); +""" + + +insert_reaction = """ + INSERT INTO reactions VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +""" + +get_complex_group_sql = """ + SELECT * FROM complexes WHERE composition_id=? AND group_id=? +""" + + +# TODO: structure these global variables better +DISPATCHER_RANK = 0 + +# message tags + +# sent by workers to the dispatcher once they have finished initializing +# only sent once +INITIALIZATION_FINISHED = 0 + +# sent by workers to the dispatcher to request a new table +SEND_ME_A_WORK_BATCH = 1 + +# sent by dispatcher to workers when delivering a new table +HERE_IS_A_WORK_BATCH = 2 + +# sent by workers to the dispatcher when reaction passes db decision tree +NEW_REACTION_DB = 3 + +# sent by workers to the dispatcher when reaction passes logging decision tree +NEW_REACTION_LOGGING = 4 + +class WorkerState(Enum): + INITIALIZING = 0 + RUNNING = 1 + FINISHED = 2 + + +def log_message(*args, **kwargs): + print( + '[' + strftime('%H:%M:%S', localtime()) + ']', + *args, **kwargs) + +def dispatcher( #input of dispatcher. 
+ mol_entries, #1 + dgl_molecules_dict, + grapher_features, + dispatcher_payload, #2 + #wx + reaction_lmdb_path +): + + comm = MPI.COMM_WORLD + work_batch_list = [] + bucket_con = sqlite3.connect(dispatcher_payload.bucket_db_file) + bucket_cur = bucket_con.cursor() + size_cur = bucket_con.cursor() + + res = bucket_cur.execute("SELECT * FROM group_counts") + for (composition_id, count) in res: + for (i,j) in product(range(count), repeat=2): + work_batch_list.append( + (composition_id, i, j)) + + composition_names = {} + res = bucket_cur.execute("SELECT * FROM compositions") + for (composition_id, composition) in res: + composition_names[composition_id] = composition + + log_message("creating reaction network db") + rn_con = sqlite3.connect(dispatcher_payload.reaction_network_db_file) + rn_cur = rn_con.cursor() + rn_cur.execute(create_metadata_table) + rn_cur.execute(create_reactions_table) + rn_con.commit() + + #### HY + ## initialize preprocess data + +#wx: writting lmdbs in dispatcher ? +#wx: each worker needs to initlize rxn_networks_graph at worker level. + rxn_networks_g = rxn_networks_graph( + mol_entries, + dgl_molecules_dict, + grapher_features, + dispatcher_payload.bondnet_test, + reaction_lmdb_path #wx. different + ) + #### + + log_message("initializing report generator") + + # since MPI processes spin lock, we don't want to have the dispathcer + # spend a bunch of time generating molecule pictures + report_generator = ReportGenerator( + mol_entries, + dispatcher_payload.report_file, + rebuild_mol_pictures=False + ) + + worker_states = {} + + worker_ranks = [i for i in range(comm.Get_size()) if i != DISPATCHER_RANK] + + for i in worker_ranks: + worker_states[i] = WorkerState.INITIALIZING + + for i in worker_states: + # block, waiting for workers to initialize + comm.recv(source=i, tag=INITIALIZATION_FINISHED) + worker_states[i] = WorkerState.RUNNING + + log_message("all workers running") + + reaction_index = 0 + + log_message("handling requests") + + batches_left_at_last_checkpoint = len(work_batch_list) + last_checkpoint_time = floor(time()) + while True: + if WorkerState.RUNNING not in worker_states.values(): + break + + current_time = floor(time()) + time_diff = current_time - last_checkpoint_time + if ( current_time % dispatcher_payload.checkpoint_interval == 0 and + time_diff > 0): + batches_left_at_current_checkpoint = len(work_batch_list) + batch_count_diff = ( + batches_left_at_last_checkpoint - + batches_left_at_current_checkpoint) + + batch_consumption_rate = batch_count_diff / time_diff + + log_message("batches remaining:", batches_left_at_current_checkpoint) + log_message("batch consumption rate:", + batch_consumption_rate, + "batches per second") + + + batches_left_at_last_checkpoint = batches_left_at_current_checkpoint + last_checkpoint_time = current_time + + + status = MPI.Status() + data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status) + tag = status.Get_tag() + rank = status.Get_source() + + #this is the last step when worker is out of work. 
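+        # Protocol recap: a worker asks for work with SEND_ME_A_WORK_BATCH; the
+        # dispatcher replies with HERE_IS_A_WORK_BATCH (None when the queue is
+        # empty, which retires that worker); accepted reactions come back as
+        # NEW_REACTION_DB and report entries as NEW_REACTION_LOGGING.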
+ if tag == SEND_ME_A_WORK_BATCH: + if len(work_batch_list) == 0: + comm.send(None, dest=rank, tag=HERE_IS_A_WORK_BATCH) + worker_states[rank] = WorkerState.FINISHED + else: + # pop removes and returns the last item in the list + work_batch = work_batch_list.pop() + comm.send(work_batch, dest=rank, tag=HERE_IS_A_WORK_BATCH) + composition_id, group_id_0, group_id_1 = work_batch + log_message( + "dispatched", + composition_names[composition_id], + ": group ids:", + group_id_0, group_id_1 + ) + #this is where worker is doing things. found a good reation and send to dispatcher. + #if this is correct, then create_rxn_networks_graph operates on worker instead of dispatcher. + #This is the reason why adding samples one by one. QA: where is batch of reactions? +#ten reactions, first filter out, second send , next eight + elif tag == NEW_REACTION_DB: + reaction = data + rn_cur.execute( + insert_reaction, + (reaction_index, + reaction['number_of_reactants'], + reaction['number_of_products'], + reaction['reactants'][0], + reaction['reactants'][1], + reaction['products'][0], + reaction['products'][1], + reaction['rate'], + reaction['dG'], + reaction['dG_barrier'], + reaction['is_redox'] + )) + + # # Create reaction graph + add to LMDB + rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) #wx in worker level + +#dispatch tracks global index, worker tracks local index in that batch. + + reaction_index += 1 + if reaction_index % dispatcher_payload.commit_frequency == 0: + rn_con.commit() + + + elif tag == NEW_REACTION_LOGGING: + + reaction = data[0] + decision_path = data[1] + + report_generator.emit_verbatim(decision_path) + report_generator.emit_reaction(reaction) + report_generator.emit_bond_breakage(reaction) + report_generator.emit_newline() + + + + log_message("finalzing database and generation report") + rn_cur.execute( + insert_metadata, + (len(mol_entries), + reaction_index) + ) + + + report_generator.finished() + rn_con.commit() + bucket_con.close() + rn_con.close() + + +def worker( + mol_entries, #input of worker + worker_payload +): + +#wx + local_reaction_idx = 0 + + comm = MPI.COMM_WORLD + con = sqlite3.connect(worker_payload.bucket_db_file) + cur = con.cursor() + + comm.send(None, dest=DISPATCHER_RANK, tag=INITIALIZATION_FINISHED) + +#wx + rank = comm.Get_rank() #id of that worker + + rxn_networks_g = rxn_networks_graph( + mol_entries, + dgl_molecules_dict, + grapher_features, + #dispatcher_payload.bondnet_test, + reaction_lmdb_path + rank #wx. 
different + ) + + + while True: + comm.send(None, dest=DISPATCHER_RANK, tag=SEND_ME_A_WORK_BATCH) + work_batch = comm.recv(source=DISPATCHER_RANK, tag=HERE_IS_A_WORK_BATCH) + + if work_batch is None: + break + + + composition_id, group_id_0, group_id_1 = work_batch + + + if group_id_0 == group_id_1: + + res = cur.execute( + get_complex_group_sql, + (composition_id, group_id_0)) + + bucket = [] + for row in res: + bucket.append((row[0],row[1])) + + iterator = permutations(bucket, r=2) + + else: + + res_0 = cur.execute( + get_complex_group_sql, + (composition_id, group_id_0)) + + bucket_0 = [] + for row in res_0: + bucket_0.append((row[0],row[1])) + + res_1 = cur.execute( + get_complex_group_sql, + (composition_id, group_id_1)) + + bucket_1 = [] + for row in res_1: + bucket_1.append((row[0],row[1])) + + iterator = product(bucket_0, bucket_1) + + + + for (reactants, products) in iterator: + reaction = { + 'reactants' : reactants, + 'products' : products, + 'number_of_reactants' : len([i for i in reactants if i != -1]), + 'number_of_products' : len([i for i in products if i != -1])} + + + decision_pathway = [] + if run_decision_tree(reaction, + mol_entries, + worker_payload.params, + worker_payload.reaction_decision_tree, + decision_pathway + ): + + comm.send( + reaction, + dest=DISPATCHER_RANK, + tag=NEW_REACTION_DB) + + #comm.send, send reaction to dispatchers. + + rxn_networks_g.create_rxn_networks_graph(reaction, local_reaction_idx) + local_reaction_idx+=1 + + + if run_decision_tree(reaction, + mol_entries, + worker_payload.params, + worker_payload.logging_decision_tree): + + comm.send( + (reaction, + '\n'.join([str(f) for f in decision_pathway]) + ), + + dest=DISPATCHER_RANK, + tag=NEW_REACTION_LOGGING) diff --git a/HiPRGen/rxn_networks_graph.py b/HiPRGen/rxn_networks_graph.py index e94b005..eb7721c 100644 --- a/HiPRGen/rxn_networks_graph.py +++ b/HiPRGen/rxn_networks_graph.py @@ -5,30 +5,34 @@ import copy from collections import defaultdict from monty.serialization import dumpfn -from bondnet.data.utils import construct_rxn_graph_empty -from HiPRGen.lmdb_dataset import LmdbDataset +from bondnet.data.utils import construct_rxn_graph_empty, create_rxn_graph +from HiPRGen.lmdb_dataset import LmdbReactionDataset import lmdb import tqdm import pickle from HiPRGen.lmdb_dataset import write_to_lmdb + class rxn_networks_graph: def __init__( self, mol_entries, dgl_molecules_dict, grapher_features, - report_file_path - ): + report_file_path, + reaction_lmdb_path #wx + ): + #wx, which one should come from molecule lmdbs? self.mol_entries = mol_entries self.dgl_mol_dict = dgl_molecules_dict self.grapher_features = grapher_features self.report_file_path = report_file_path + self.reaction_lmdb_path = reaction_lmdb_path + # initialize data self.data = {} - def create_rxn_networks_graph(self, rxn, rxn_id): """ @@ -62,7 +66,11 @@ def create_rxn_networks_graph(self, rxn, rxn_id): The goal of this function is to create a reaction graph with a reaction filtered from HiPRGen, specifically, reaction_filter.py """ - + #wx: + #rxn: 1. reactants, products, number_of_reactants, number_of_products, is_redox, dG, dG_barrier, rate, atom_mp. 
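+        # Illustrative record only (keys from the comment above, values invented):
+        #   rxn = {"reactants": (12, 7), "products": (30, -1),
+        #          "number_of_reactants": 2, "number_of_products": 1,
+        #          "is_redox": 0, "dG": -0.42, "dG_barrier": 0.0, "rate": 1.3e6,
+        #          "atom_map": {...}}    # -1 marks a missing second species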
+ #import pdb; + #pdb.set_trace() + atom_map = rxn['atom_map'] num_reactants = rxn['number_of_reactants'] num_products = rxn['number_of_products'] @@ -101,8 +109,7 @@ def create_rxn_networks_graph(self, rxn, rxn_id): # Find "total_atoms" in mapping num_tot_atoms = sum([len(i) for i in reactants]) total_atoms = [i for i in range(num_tot_atoms)] - - + # step 2: Get total_bonds which are a union of bonds in reactants and products # define a function to find total bonds in reactants or products @@ -156,9 +163,6 @@ def find_total_bonds(rxn, species, reactants, products): if rxn['is_redox']: assert len(set(reactants_total_bonds)) == len(set(products_total_bonds)) - - - # step 3: Get bond_mapping bond_mapping = [] @@ -208,12 +212,13 @@ def find_total_bonds(rxn, species, reactants, products): mappings['num_atoms_total'] = len(total_atoms) # if rxn['is_redox']: # print(f"mappings: {mappings}") - #print(f"mapping: {mappings}") + # print(f"mapping: {mappings}") # print(f"atom_map: {atom_map}") # print(f"reactants: {reactants}") # print(f"products: {products}") # grab dgl graphs of reactants or products + #wx, extract mol.feature_info. from here, and reactants_dgl_graphs = [self.dgl_mol_dict[entry_i] for entry_i in reactants_entry_ids] products_dgl_graphs = [self.dgl_mol_dict[entry_i] for entry_i in products_entry_ids] # print(f"reactants_dgl_graphs: {reactants_dgl_graphs}") @@ -229,50 +234,87 @@ def find_total_bonds(rxn, species, reactants, products): # print(f"has_bonds: {has_bonds}") # print(f"mappings: {mappings}") - # # step 5: Create a reaction graphs and features - # rxn_graph, features = create_rxn_graph( - # reactants = reactants_dgl_graphs, - # products = products_dgl_graphs, - # mappings = mappings, - # has_bonds = has_bonds, - # device = None, - # ntypes=("global", "atom", "bond"), - # ft_name="feat", - # reverse=False, - # zero_fts=True, - # ) - - # # print(f"rxn_graph: {rxn_graph}") - # if rxn['is_redox']: - # print(f"mappings: {mappings}") - # print(f"features: {features}") - # print(f"transformed_atom_map: {transformed_atom_map}") - # print(f"atom_map: {atom_map}") - - # # step 5: update reaction features to the reaction graph - # for nt, ft in features.items(): - # # print(f"nt: {nt}") - # # print(f"ft: {ft}") - # rxn_graph.nodes[nt].data.update({'ft': ft}) - - rxn_graph = construct_rxn_graph_empty(mappings) - - # step 6: save a reaction graph and dG - self.data[rxn_id] = {} # {'id': {}} - self.data[rxn_id]['rxn_graph'] = rxn_graph - self.data[rxn_id]['value'] = rxn['dG'] #torch.tensor([rxn['dG']]) - self.data[rxn_id]['mappings'] = mappings - # self.data[rxn_id]['reaction_features'] = features - + #wx, step 5: create an empty graph and set features to be all zero + rxn_graph, features = create_rxn_graph( + reactants = reactants_dgl_graphs, + products = products_dgl_graphs, + mappings = mappings, + has_bonds = has_bonds, + device = None, + ntypes=("global", "atom", "bond"), + ft_name="feat", + reverse=False, + zero_fts=True, + ) + + #wx, step 6: structure and save reaction graph. + self.data["reaction_index"] = rxn_id + self.data["reaction_graph"] = rxn_graph + self.data["reaction_feature"] = features + self.data["reaction_molecule_info"] = { + "reactants" : { "reactants" : list(rxn["reactants"]), #TODO, keep unique value if two reactants are same? 
+ "atom_map": mappings["atom_map"][0], + "bond_map": mappings["bond_map"][0], + "init_reactants": list(rxn["reactants"]) + + }, + "products" : { "products" : list(rxn["products"]), # rxn['reactants'] (94, 96) + "atom_map": mappings["atom_map"][1], + "bond_map": mappings["bond_map"][1], + "init_products": list(rxn["products"]) + + } + } + # self.data["reaction_molecule_info"] = { + # "reactants" : { "molecule_index" : list(rxn["reactants"]), #TODO, keep unique value if two reactants are same? + # "atom_map": mappings["atom_map"][0], + # "bond_map": mappings["bond_map"][0] + + # }, + # "products" : { "molecule_index" : list(rxn["products"]), # rxn['reactants'] (94, 96) + # "atom_map": mappings["atom_map"][1], + # "bond_map": mappings["bond_map"][1] + # } + # } + self.data["label"] = torch.Tensor([rxn['dG']]) + self.data["reverse_label"] = torch.Tensor([0]) #TODO + self.data["extra_info"] = {} #TODO + self.data["mappings"] = mappings + self.data["has_bonds"] = has_bonds + + # # # print(f"rxn_graph: {rxn_graph}") + # # if rxn['is_redox']: + # # print(f"mappings: {mappings}") + # # print(f"features: {features}") + # # print(f"transformed_atom_map: {transformed_atom_map}") + # # print(f"atom_map: {atom_map}") + + #wx, since we create empty graph. just neglect this step + # # # step 5: update reaction features to the reaction graph + # # for nt, ft in features.items(): + # # # print(f"nt: {nt}") + # # # print(f"ft: {ft}") + # # rxn_graph.nodes[nt].data.update({'ft': ft}) + + # rxn_graph = construct_rxn_graph_empty(mappings) + + # # step 6: save a reaction graph and dG + # self.data[rxn_id] = {} # {'id': {}} + # self.data[rxn_id]['rxn_graph'] = rxn_graph + # self.data[rxn_id]['value'] = rxn['dG'] #torch.tensor([rxn['dG']]) + # self.data[rxn_id]['mappings'] = mappings + # # self.data[rxn_id]['reaction_features'] = features + + #### Write LMDB #### #1 load lmdb - - lmdb_path = "training_trial5.lmdb" + + lmdb_path = self.reaction_lmdb_path lmdb_file = Path(lmdb_path) if lmdb_file.is_file(): # file exists - current_lmdb = LmdbDataset({'src': lmdb_path}) + current_lmdb = LmdbReactionDataset({'src': lmdb_path}) lmdb_update = { "mean" : current_lmdb.mean, "std": current_lmdb.std, @@ -315,15 +357,19 @@ def find_total_bonds(rxn, species, reactants, products): updated_variance = (n-1)/(n)*prev_variance + (n-1)/n*(prev_mean-updated_mean)**2 + (current_y - updated_mean)**2/n lmdb_update["std"] = math.sqrt(updated_variance) - - labels = {'value': torch.tensor([rxn['dG']]), 'value_rev': torch.tensor([0]), 'id': [str(rxn_id)], "reaction_type": ['']} - data = (self.data[rxn_id]['rxn_graph'], self.data[rxn_id]['reaction_features'], labels) - # print(f"data: {data}") - # print(f"lmdb_update: {lmdb_update}") - write_to_lmdb([data], current_length, lmdb_update, lmdb_path) + write_to_lmdb([self.data], current_length, lmdb_update, lmdb_path) - - def write_data(self): - # write a json file - dumpfn(self.data, self.report_file_path) \ No newline at end of file + # import pdb + # pdb.set_trace() + + # labels = {'value': torch.tensor([rxn['dG']]), 'value_rev': torch.tensor([0]), 'id': [str(rxn_id)], "reaction_type": ['']} + # data = (self.data[rxn_id]['rxn_graph'], self.data[rxn_id]['reaction_features'], labels) + # # print(f"data: {data}") + # # print(f"lmdb_update: {lmdb_update}") + # write_to_lmdb([data], current_length, lmdb_update, lmdb_path) + +#wx no need this. 
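+        # The update above keeps single-pass statistics over every dG written so
+        # far: with n total samples and updated_mean the new running mean,
+        #   v_n = ((n-1)*v_{n-1} + (n-1)*(prev_mean - updated_mean)**2
+        #          + (current_y - updated_mean)**2) / n
+        # is the exact population-variance recurrence, so sqrt(v_n) can be
+        # stored as "std" without ever re-reading the lmdb.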
+ # def write_data(self): + # # write a json file + # dumpfn(self.data, self.report_file_path) \ No newline at end of file diff --git a/HiPRGen/species_filter.py b/HiPRGen/species_filter.py index 82d87cd..7f43a32 100644 --- a/HiPRGen/species_filter.py +++ b/HiPRGen/species_filter.py @@ -22,6 +22,12 @@ from bondnet.core.molwrapper import create_wrapper_mol_from_atoms_and_bonds from bondnet.utils import int_atom +#wx +from bondnet.data.utils import find_rings +from HiPRGen.lmdb_dataset import dump_molecule_lmdb +import numpy as np +import copy + """ Phase 1: species filtering input: a list of dataset entries @@ -96,6 +102,7 @@ def species_filter( coordimer_weight, species_logging_decision_tree=Terminal.DISCARD, generate_unfiltered_mol_pictures=False, + mol_lmdb_path = None, #path to molecular lmdb ): """ @@ -230,8 +237,16 @@ def collapse_isomorphism_group(g): dgl_molecules = [] extra_keys = [] + #wx, dump molecule lmdbs. + pmg_objects = [] + molecule_ind_list = [] + charge_set = set() + ring_size_set = set() + element_set = set() + + # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS - for mol in mol_entries: + for ind, mol in enumerate(mol_entries): # print(f"mol: {mol.mol_graph}") molecule_grapher = get_grapher( features = extra_keys, @@ -285,6 +300,28 @@ def collapse_isomorphism_group(g): fts = dgl_molecule_graph.nodes[nt].data["feat"] print(f"features: {fts}") dgl_molecules_dict[mol.entry_id] = mol.ind + + + # #wx, collect data for molecule lmdbs + pmg_objects.append(mol_wrapper.pymatgen_mol) + molecule_ind_list.append(ind) #can use mol.ind + + _formula = mol_wrapper.pymatgen_mol.composition.formula.split() + _elements = [clean(x) for x in _formula] #otherwise override variable. + element_set.update(_elements) + atom_num = np.sum(np.array([int(clean_op(x)) for x in _formula])) + + charge = mol_wrapper.pymatgen_mol.charge + charge_set.add(charge) + bond_list = [[i[0], i[1]] for i in mol_wrapper.mol_graph.graph.edges] + cycles = find_rings(atom_num, bond_list, edges=False) + ring_len_list = [len(i) for i in cycles] + ring_size_set.update(ring_len_list) + + + # import pdb + # pdb.set_trace() + print(molecule_grapher.feature_name) grapher_features= {'feature_size':molecule_grapher.feature_size, 'feature_name': molecule_grapher.feature_name} #mol_wrapper_dict[mol.entry_id] = mol_wrapper @@ -297,7 +334,6 @@ def collapse_isomorphism_group(g): # print(f"mean: {scaler._mean}") # print(f"std: {scaler._std}") - # Create a dictionary where key is mol.entry_id and value is a normalized dgl molecule graph for key in dgl_molecules_dict.keys(): temp_index = dgl_molecules_dict[key] @@ -306,7 +342,16 @@ def collapse_isomorphism_group(g): #ADD WRITING MOLECULE LMDB HERE!!! - + #wx, dump molecule lmdb. + dump_molecule_lmdb( + molecule_ind_list, + dgl_molecules, #dgl_graphs + pmg_objects, + charge_set, + ring_size_set, + element_set, + mol_lmdb_path + ) log_message("creating molecule entry pickle") # ideally we would serialize mol_entries to a json @@ -322,6 +367,7 @@ def collapse_isomorphism_group(g): with open(grapher_features_pickle_location, "wb") as f: pickle.dump(grapher_features, f) + log_message("species filtering finished. 
" + str(len(mol_entries)) + " species") return mol_entries, dgl_molecules_dict @@ -355,4 +401,10 @@ def add_electron_species( mol_entries.append(electron_entry) with open(mol_entries_pickle_location, "wb") as f: pickle.dump(mol_entries, f) - return mol_entries \ No newline at end of file + return mol_entries + + +def clean(input): + return "".join([i for i in input if not i.isdigit()]) +def clean_op(input): + return "".join([i for i in input if i.isdigit()]) \ No newline at end of file diff --git a/run_network_generation.py b/run_network_generation.py index 42a055b..d42e9af 100644 --- a/run_network_generation.py +++ b/run_network_generation.py @@ -21,6 +21,9 @@ worker_payload_json = sys.argv[3] dgl_molecules_dict_pickle_file = sys.argv[4] grapher_features_dict_pickle_file = sys.argv[5] +#wx,add reaction lmdb_path +reaction_lmdb_path = sys.argv[6] + with open(mol_entries_pickle_file, 'rb') as f: mol_entries = pickle.load(f) @@ -36,7 +39,9 @@ dispatcher(mol_entries, dgl_molecules_dict_pickle_file, grapher_features_dict_pickle_file, - dispatcher_payload + dispatcher_payload, + #wx, + reaction_lmdb_path ) else: diff --git a/test.py b/test.py index 6bf38c9..8eee93b 100644 --- a/test.py +++ b/test.py @@ -592,6 +592,7 @@ def euvl_phase1_test(): "electron_free_energy": 0.0, } + #wx, dump molecule lmdbs inside species_filter function. mol_entries = species_filter( database_entries, mol_entries_pickle_location=folder + "/mol_entries.pickle", @@ -642,7 +643,7 @@ def euvl_phase1_test(): "run_network_generation.py", folder + "/mol_entries.pickle", folder + "/dispatcher_payload.json", - folder + "/worker_payload.json", + folder + "/worker_payload.json" ] ) @@ -899,7 +900,11 @@ def euvl_bondnet_test(): "electron_free_energy": 0.0, } + # import pdb + # pdb.set_trace() + mol_entries, dgl_molecules_dict = species_filter( + #wx: dump mol lmdb at the end of species filter. database_entries, mol_entries_pickle_location=folder + "/mol_entries.pickle", dgl_mol_grphs_pickle_location = folder + "/dgl_mol_graphs.pickle", @@ -909,14 +914,18 @@ def euvl_bondnet_test(): coordimer_weight=lambda mol: (mol.get_free_energy(params["temperature"])), species_logging_decision_tree=species_decision_tree, generate_unfiltered_mol_pictures=False, + mol_lmdb_path = folder + "/mol.lmdb", ) + print(len(mol_entries), "initial mol entries") bucket(mol_entries, folder + "/buckets.sqlite") print(len(mol_entries), "final mol entries") +#first test terminates here. + dispatcher_payload = DispatcherPayload( folder + "/buckets.sqlite", folder + "/rn.sqlite", @@ -936,7 +945,7 @@ def euvl_bondnet_test(): subprocess.run( [ - "mpirun", + "mpirun", #call mpi. 
"--use-hwthread-cpus", "-n", number_of_threads, @@ -946,7 +955,10 @@ def euvl_bondnet_test(): folder + "/dispatcher_payload.json", folder + "/worker_payload.json", folder + "/dgl_mol_graphs.pickle", - folder + "/grapher_features.pickle" + folder + "/grapher_features.pickle", + + #wx, path to write reaction lmdb + folder + "/reaction.lmdb" ] ) From 309599c9d65b3a0484e3f33c6b7fcdcbf40322c2 Mon Sep 17 00:00:00 2001 From: Wenbin Xu Date: Thu, 7 Dec 2023 11:50:00 -0800 Subject: [PATCH 2/3] HiPRGen-BonDNet lmdb dispatcher v2 --- HiPRGen/.reaction_filter.py.swp | Bin 16384 -> 0 bytes HiPRGen/.reaction_filter_new.py.swp | Bin 16384 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 HiPRGen/.reaction_filter.py.swp delete mode 100644 HiPRGen/.reaction_filter_new.py.swp diff --git a/HiPRGen/.reaction_filter.py.swp b/HiPRGen/.reaction_filter.py.swp deleted file mode 100644 index cbcbd2540f6c0974107d1a92d71c7a6bac7ccdb8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeHNO>87b6|TUBzYqc>2rgMF4$jOh!>reKl;Fr(#&*XZc4lnN%*JdMt5wroGgE9& zcW+nEj=e;nhzk)xAqfYBga9ECMIZ%`B5~jX2qA6>2u>Uk2q~O60V3h6>Yw@X;#f#r zqH6imxT{{hdiB+-SFfwPyYyJCK_9BD5?t>mJFZb^uWX%1p@DE3->ZHqG zT6%D~;ze4=34&k96rOik%+BwKz1BAOKd}+`TTLEk_va(doLB_D-4kBSqe^&mf=*Lq z7%&W+mVrCT%JPLJC4K+Gee~`RUOsI@BWD;e3>XFs1BL;^fMLKeU>GnAyagDD!!zWI z7{}Yvv1HHMx25;l?;GjmrR6v0>R(UmpG?dDJXinr^!~l+Z$DL>=}(Kzk72+tU>GnA z7zPXjh5^HXVZbn87%&VN1`GrL2L_l+$RhMEssjML|IeQP|9Tf8F9BZyn!rWi9^m!& z5%L${r@(iCXMnE)PXi9n04@Qa25Nu>oCWR${&6QEzXQGxoCWR${*42HUjg3(Ubq8w zzyWXt_zbWQJOX?a_}hC3`5Evs@MGXdz&8O21b`1b4BQ6%>D`3<8h8rW2Nr?9zl)Hc z1J3{lKm~Z?orJssd>i-%@O9vM;5pz4;FG|Ez$)+&;1BO0p+t0BDQs94e>6UCz;`DO2r9$pF45U6*6S8GvJXO zvLi2GZY5RfcGV{*Jn3j0MAz)D60|)bV<@_Mpe+CgL)N~Y+z4mxg zxI`R}NlEuXGLPCZgMV25Xb2KX)7rONR&}#e+ilwQ-Rhm#;S=WQnZL7IuH>`QB7fxk4 zI*Iu85r;d^D9vZ5s!5-k%_Zm!eD5fsKJXoWs*aSPnwldzU^h7J317$ocM&KMJD4YK zVqEwJg{FRm!Sy8}TebF{V##XRt!ncU$Wq;RcDHudl=qGtwPci9$!XsUy3C_DSR@oN zl6HBFkHz4R3kwU{m9(*wP=z=G zirKAN$TN?yg`1Ck-YUA+sx_*uEA*0eCE;xo>~}2MsczJ*DR(Rq2l`^3`Xi z>l-U78Vf8c6-<2|t3ZUH$y_+jSa=@R7UhQsnnSK4fftM`SeGKELQ;Wc6vf~ijai&` zt1L&YQi!c&$f6^9jU%_r7U%TJ8SaiVMwCBJGReurKa&PB z#J+&Q@ZaWfhS%JjoYt&lhSg-GYWPmza+M&(1Ms2J3&o4erRe(zh$m3XabR+G&FL=MdM1Yzx)1A^Pq{CXNQ@bGN1-Pi z_!mYM_ysMFCO%j3K}G!BsPRBJ1B85rsTpI}*3v$Qka8{tQI{Vk`01SF89DP5+@eJNp#Mgj(uf{R`8-N zPCmQVGHFkl!k47~Li zz-igB3#a(&w|aXsc@s6=1ZOMp`rMI!dGa=ET9lS%Plofl!e{+L<#U1Xi&yc2px7(v z*OWbVnv$PE;AP)n%#IFjZACFVkT~80&f|s^b&59ijxFzb{-5MLE<1Usq_3Q2PeP_% k5zaTy@8|TWGoj(sTf$s!{-2pwgwvG*XGQZzGE;;859QZ8Jpcdz diff --git a/HiPRGen/.reaction_filter_new.py.swp b/HiPRGen/.reaction_filter_new.py.swp deleted file mode 100644 index 05f29ca899322cef97a635f3c24942d92ccfdec9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeHNU5p!76~1lLmcNpUR-!KrH@ix0ryFNC*$TSVD(sKDapU!_yiS)iEW>!_+8#2V zap#WX-6T*Dyg-X8S_BAz03pf)LWKu@BoGoHL=g`ZDwPM2kPyEQzynWEk?@^6GuLB# zvsoefQjO(L@Kex&}E`2Ki&eNQ|(67V$eao{Py0#<=Dz#~8zC;^9ndx3x5!PrZ{bHE|sUf@4( zV(cftmw~Uo5p}>8a1MAEun9Z_ybbu<8yNdO@Gan*z&C);0ut~64|oS~6Yz)EGxk&9 zqrfI`1o-Fc82cgcG_VB}for!j_9E~F;4{D{foFkdfGfa*zyrW6a1ZeEZH)a2_$BZn z@I~M$;7Q;La2YrO90qO$4gf#96}ke?0}^-)I0qaAUOB+nzk#m-&j6nQwgC&+02YA- z;GMuB;4a{|n6MuMKLS1nd<19%Rp1zK7`P3%1-Kdb3+DD0z>`1|SOOLSHRq8}aeM+H z`8;p*tP3JEtX@zQ+rsLFB0sTFs+5-M{3w68wzkS!f#2&Imf!P2IZ6#JLK?PXg@Zbj8^Q|%2Oaq2 z46{X(4~EL~NyqEz6B zM^@*KY~-f6>szL)N(%+F$!+B#b-ID$g?Xic>$i-KkkV|4{OAyK@E#iB;c~USak{h& zb|VsMXRFn6bp`h`ssZZj<%LT5v8ZsGS*|H1-^5gjU?Vi)WhJlI!BntnY162c7MAMe 
From 8adf0b4cc25a8f121e72395eae7b7bfb0a379e46 Mon Sep 17 00:00:00 2001
From: Wenbin Xu
Date: Thu, 7 Dec 2023 17:54:03 -0800
Subject: [PATCH 3/3] HiPRGen-BonDNet lmdb worker level

---
 HiPRGen/lmdb_dataset.py       | 189 ++++++++++++----
 HiPRGen/reaction_filter.py    |  81 ++++---
 HiPRGen/reaction_filter_cp.py | 396 ++++++++++++++++++++++++++++
 HiPRGen/rxn_networks_graph.py |   4 +-
 run_network_generation.py     |  24 ++-
 test.py                       |  22 +-
 6 files changed, 637 insertions(+), 79 deletions(-)
 create mode 100644 HiPRGen/reaction_filter_cp.py

diff --git a/HiPRGen/lmdb_dataset.py b/HiPRGen/lmdb_dataset.py
index 8a7111d..4034c1c 100644
--- a/HiPRGen/lmdb_dataset.py
+++ b/HiPRGen/lmdb_dataset.py
@@ -31,22 +31,50 @@ def __init__(self, config, transform=None):
 
         self.config = config
         self.path = Path(self.config["src"])
 
-        # Get metadata in case
-        # self.metadata_path = self.path.parent / "metadata.npz"
-        self.env = self.connect_db(self.path)
-
-        # If "length" encoded as ascii is present, use that
-        # If there are additional properties, there must be length.
- length_entry = self.env.begin().get("length".encode("ascii")) - if length_entry is not None: - num_entries = pickle.loads(length_entry) + if not self.path.is_file(): + db_paths = sorted(self.path.glob("*.lmdb")) + assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" + #self.metadata_path = self.path / "metadata.npz" + + self._keys = [] + self.envs = [] + for db_path in db_paths: + cur_env = self.connect_db(db_path) + self.envs.append(cur_env) + + # If "length" encoded as ascii is present, use that + length_entry = cur_env.begin().get("length".encode("ascii")) + if length_entry is not None: + num_entries = pickle.loads(length_entry) + else: + # Get the number of stores data from the number of entries in the LMDB + num_entries = cur_env.stat()["entries"] + + # Append the keys (0->num_entries) as a list + self._keys.append(list(range(num_entries))) + + keylens = [len(k) for k in self._keys] + self._keylen_cumulative = np.cumsum(keylens).tolist() + self.num_samples = sum(keylens) + + else: - # Get the number of stores data from the number of entries - # in the LMDB - num_entries = self.env.stat()["entries"] - - self._keys = list(range(num_entries)) - self.num_samples = num_entries + # Get metadata in case + # self.metadata_path = self.path.parent / "metadata.npz" + self.env = self.connect_db(self.path) + + # If "length" encoded as ascii is present, use that + # If there are additional properties, there must be length. + length_entry = self.env.begin().get("length".encode("ascii")) + if length_entry is not None: + num_entries = pickle.loads(length_entry) + else: + # Get the number of stores data from the number of entries + # in the LMDB + num_entries = self.env.stat()["entries"] + + self._keys = list(range(num_entries)) + self.num_samples = num_entries # Get portion of total dataset self.sharded = False @@ -71,15 +99,34 @@ def __getitem__(self, idx): # if sharding, remap idx to appropriate idx of the sharded set if self.sharded: idx = self.available_indices[idx] + + if not self.path.is_file(): + # Figure out which db this should be indexed from. + db_idx = bisect.bisect(self._keylen_cumulative, idx) + # Extract index of element within that db. + el_idx = idx + if db_idx != 0: + el_idx = idx - self._keylen_cumulative[db_idx - 1] + assert el_idx >= 0 + + # Return features. + datapoint_pickled = ( + self.envs[db_idx] + .begin() + .get(f"{self._keys[db_idx][el_idx]}".encode("ascii")) + ) + data_object = pickle.loads(datapoint_pickled) + #data_object.id = f"{db_idx}_{el_idx}" + + else: + #!CHECK, _keys should be less then total numbers of keys as there are more properties. + datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii")) - #!CHECK, _keys should be less then total numbers of keys as there are more properties. 
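+            # (In the multi-file branch above, a global idx maps to a pair
+            # (db_idx, el_idx) through the cumulative per-file key counts:
+            # db_idx comes from bisect on _keylen_cumulative, then
+            # el_idx = idx - _keylen_cumulative[db_idx - 1].)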
- datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii")) - - data_object = pickle.loads(datapoint_pickled) + data_object = pickle.loads(datapoint_pickled) - # TODO - if self.transform is not None: - data_object = self.transform(data_object) + # TODO + if self.transform is not None: + data_object = self.transform(data_object) return data_object @@ -109,25 +156,30 @@ def get_metadata(self, num_samples=100): class LmdbMoleculeDataset(LmdbBaseDataset): def __init__(self, config, transform=None): super(LmdbMoleculeDataset, self).__init__(config=config, transform=transform) - + if not self.path.is_file(): + self.env_ = self.envs[0] + raise("Not Implemented Yet") + + else: + self.env_ = self.env @property def charges(self): - charges = self.env.begin().get("charges".encode("ascii")) + charges = self.env_.begin().get("charges".encode("ascii")) return pickle.loads(charges) @property def ring_sizes(self): - ring_sizes = self.env.begin().get("ring_sizes".encode("ascii")) + ring_sizes = self.env_.begin().get("ring_sizes".encode("ascii")) return pickle.loads(ring_sizes) @property def elements(self): - elements = self.env.begin().get("elements".encode("ascii")) + elements = self.env_.begin().get("elements".encode("ascii")) return pickle.loads(elements) @property def feature_info(self): - feature_info = self.env.begin().get("feature_info".encode("ascii")) + feature_info = self.env_.begin().get("feature_info".encode("ascii")) return pickle.loads(feature_info) @@ -135,30 +187,83 @@ class LmdbReactionDataset(LmdbBaseDataset): def __init__(self, config, transform=None): super(LmdbReactionDataset, self).__init__(config=config, transform=transform) + if not self.path.is_file(): + self.env_ = self.envs[0] + #get keys + for i in range(1, len(self.envs)): + for key in ["feature_size", "dtype", "feature_name"]: #, "mean", "std"]: + assert self.envs[i].begin().get(key.encode("ascii")) == self.envs[0].begin().get(key.encode("ascii")) + #! mean and std are not equal across different dataset at this time. 
+ #get mean and std + mean_list = [pickle.loads(self.envs[i].begin().get("mean".encode("ascii"))) for i in range(0, len(self.envs))] + std_list = [pickle.loads(self.envs[i].begin().get("std".encode("ascii"))) for i in range(0, len(self.envs))] + count_list = [pickle.loads(self.envs[i].begin().get("length".encode("ascii"))) for i in range(0, len(self.envs))] + self._mean, self._std = combined_mean_std(mean_list, std_list, count_list) + + else: + self.env_ = self.env + self._mean = pickle.loads(self.env_.begin().get("mean".encode("ascii"))) + self._std = pickle.loads(self.env_.begin().get("std".encode("ascii"))) + @property def dtype(self): - dtype = self.env.begin().get("dtype".encode("ascii")) + dtype = self.env_.begin().get("dtype".encode("ascii")) return pickle.loads(dtype) - + @property def feature_size(self): - feature_size = self.env.begin().get("feature_size".encode("ascii")) + feature_size = self.env_.begin().get("feature_size".encode("ascii")) return pickle.loads(feature_size) @property def feature_name(self): - feature_name = self.env.begin().get("feature_name".encode("ascii")) + feature_name = self.env_.begin().get("feature_name".encode("ascii")) return pickle.loads(feature_name) - + @property def mean(self): - mean = self.env.begin().get("mean".encode("ascii")) - return pickle.loads(mean) - + return self._mean + @property def std(self): - std = self.env.begin().get("std".encode("ascii")) - return pickle.loads(std) + #std = self.env_.begin().get("std".encode("ascii")) + return self._std + +# @property +# def mean(self): +# mean = self.env_.begin().get("mean".encode("ascii")) +# return pickle.loads(mean) + +# @property +# def std(self): +# std = self.env_.begin().get("std".encode("ascii")) +# return pickle.loads(std) + + +def combined_mean_std(mean_list, std_list, count_list): + """ + Calculate the combined mean and standard deviation of multiple datasets. + + :param mean_list: List of means of the datasets. + :param std_list: List of standard deviations of the datasets. + :param count_list: List of number of data points in each dataset. + :return: Combined mean and standard deviation. 
+ """ + # Calculate total number of data points + total_count = sum(count_list) + + # Calculate combined mean + combined_mean = sum(mean * count for mean, count in zip(mean_list, count_list)) / total_count + + # Calculate combined variance + combined_variance = sum( + ((std ** 2) * (count - 1) + count * (mean - combined_mean) ** 2 for mean, std, count in zip(mean_list, std_list, count_list)) + ) / (total_count - len(mean_list)) + + # Calculate combined standard deviation + combined_std = (combined_variance ** 0.5) + + return combined_mean, combined_std @@ -442,10 +547,10 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): map_async=True, ) - pbar = tqdm( - total=len(new_samples), - desc=f"Adding new samples into LMDBs", - ) + # pbar = tqdm( + # total=len(new_samples), + # desc=f"Adding new samples into LMDBs", + # ) #write indexed samples idx = current_length @@ -456,7 +561,7 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): pickle.dumps(sample, protocol=-1), ) idx += 1 - pbar.update(1) + #pbar.update(1) txn.commit() #write properties diff --git a/HiPRGen/reaction_filter.py b/HiPRGen/reaction_filter.py index 53d6d83..914e6d9 100644 --- a/HiPRGen/reaction_filter.py +++ b/HiPRGen/reaction_filter.py @@ -104,12 +104,13 @@ def log_message(*args, **kwargs): '[' + strftime('%H:%M:%S', localtime()) + ']', *args, **kwargs) +#restructure input of dispatcher def dispatcher( - mol_entries, - dgl_molecules_dict, - grapher_features, - dispatcher_payload, - reaction_lmdb_path + mol_entries, #1 + #dgl_molecules_dict, + #grapher_features, + dispatcher_payload, #2 + #reaction_lmdb_path ): comm = MPI.COMM_WORLD @@ -138,16 +139,17 @@ def dispatcher( #### HY ## initialize preprocess data - -#wx: writting lmdbs in dispatcher ? - rxn_networks_g = rxn_networks_graph( - mol_entries, - dgl_molecules_dict, - grapher_features, - dispatcher_payload.bondnet_test, - reaction_lmdb_path - ) - #### + +# #wx: writting lmdbs in dispatcher ? +# #wx: each worker needs to initlize rxn_networks_graph at worker level. +# rxn_networks_g = rxn_networks_graph( +# mol_entries, +# dgl_molecules_dict, +# grapher_features, +# dispatcher_payload.bondnet_test, +# reaction_lmdb_path #wx. different +# ) +# #### log_message("initializing report generator") @@ -162,6 +164,7 @@ def dispatcher( worker_states = {} worker_ranks = [i for i in range(comm.Get_size()) if i != DISPATCHER_RANK] + print("worker_states",worker_states) for i in worker_ranks: worker_states[i] = WorkerState.INITIALIZING @@ -173,6 +176,7 @@ def dispatcher( log_message("all workers running") + #global index which is different with local index in worker reaction_index = 0 log_message("handling requests") @@ -209,6 +213,7 @@ def dispatcher( tag = status.Get_tag() rank = status.Get_source() + #this is the last step when worker is out of work. if tag == SEND_ME_A_WORK_BATCH: if len(work_batch_list) == 0: comm.send(None, dest=rank, tag=HERE_IS_A_WORK_BATCH) @@ -224,9 +229,11 @@ def dispatcher( ": group ids:", group_id_0, group_id_1 ) - - - elif tag == NEW_REACTION_DB: + #this is where worker is doing things. found a good reation and send to dispatcher. + #if this is correct, then create_rxn_networks_graph operates on worker instead of dispatcher. + #This is the reason why adding samples one by one. QA: where is batch of reactions? 
+#ten reactions, first filter out, second send , next eight + elif tag == NEW_REACTION_DB: reaction = data rn_cur.execute( insert_reaction, @@ -243,8 +250,10 @@ def dispatcher( reaction['is_redox'] )) - # Create reaction graph + add to LMDB - rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) + # # # Create reaction graph + add to LMDB + # rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) #wx in worker level + +#dispatch tracks global index, worker tracks local index in that batch. reaction_index += 1 if reaction_index % dispatcher_payload.commit_frequency == 0: @@ -278,17 +287,37 @@ def dispatcher( def worker( - mol_entries, - worker_payload + mol_entries, #input of worker + worker_payload, + dgl_molecules_dict, + grapher_features, + reaction_lmdb_path + ): + # import pdb + # pdb.set_trace() + + local_reaction_idx = 0 #wx add local_idx + comm = MPI.COMM_WORLD con = sqlite3.connect(worker_payload.bucket_db_file) cur = con.cursor() - comm.send(None, dest=DISPATCHER_RANK, tag=INITIALIZATION_FINISHED) + rank = comm.Get_rank() #get id of that worker + + lmdb_name_i = reaction_lmdb_path.split(".lmdb")[0] + "_" + str(rank) + ".lmdb" + + rxn_networks_g = rxn_networks_graph( + mol_entries, + dgl_molecules_dict, + grapher_features, + #dispatcher_payload.bondnet_test, #can be removed + lmdb_name_i #wx. different + ) + while True: comm.send(None, dest=DISPATCHER_RANK, tag=SEND_ME_A_WORK_BATCH) work_batch = comm.recv(source=DISPATCHER_RANK, tag=HERE_IS_A_WORK_BATCH) @@ -355,6 +384,10 @@ def worker( dest=DISPATCHER_RANK, tag=NEW_REACTION_DB) + #comm.send, send reaction to dispatchers. + + rxn_networks_g.create_rxn_networks_graph(reaction, local_reaction_idx) + local_reaction_idx+=1 if run_decision_tree(reaction, @@ -368,4 +401,4 @@ def worker( ), dest=DISPATCHER_RANK, - tag=NEW_REACTION_LOGGING) \ No newline at end of file + tag=NEW_REACTION_LOGGING) diff --git a/HiPRGen/reaction_filter_cp.py b/HiPRGen/reaction_filter_cp.py new file mode 100644 index 0000000..c5bd263 --- /dev/null +++ b/HiPRGen/reaction_filter_cp.py @@ -0,0 +1,396 @@ +from mpi4py import MPI +from HiPRGen.rxn_networks_graph import rxn_networks_graph +from itertools import permutations, product +from HiPRGen.report_generator import ReportGenerator +import sqlite3 +from time import localtime, strftime, time +from enum import Enum +from math import floor +from HiPRGen.reaction_filter_payloads import ( + DispatcherPayload, + WorkerPayload +) + +from HiPRGen.reaction_questions import ( + run_decision_tree +) + + +""" +Phases 3 & 4 run in parallel using MPI + +Phase 3: reaction gen and filtering +input: a bucket labeled by atom count +output: a list of reactions from that bucket +description: Loop through all possible reactions in the bucket and apply the decision tree. This will run in parallel over each bucket. + +Phase 4: collating and indexing +input: all the outputs of phase 3 as they are generated +output: reaction network database +description: the worker processes from phase 3 are sending their reactions to this phase and it is writing them to DB as it gets them. We can ensure that duplicates don't get generated in phase 3 which means we don't need extra index tables on the db. + +the code in this file is designed to run on a compute cluster using MPI. +""" + + +create_metadata_table = """ + CREATE TABLE metadata ( + number_of_species INTEGER NOT NULL, + number_of_reactions INTEGER NOT NULL + ); +""" + +insert_metadata = """ + INSERT INTO metadata VALUES (?, ?) 
+""" + +# it is important that reaction_id is the primary key +# otherwise the network loader will be extremely slow. +create_reactions_table = """ + CREATE TABLE reactions ( + reaction_id INTEGER NOT NULL PRIMARY KEY, + number_of_reactants INTEGER NOT NULL, + number_of_products INTEGER NOT NULL, + reactant_1 INTEGER NOT NULL, + reactant_2 INTEGER NOT NULL, + product_1 INTEGER NOT NULL, + product_2 INTEGER NOT NULL, + rate REAL NOT NULL, + dG REAL NOT NULL, + dG_barrier REAL NOT NULL, + is_redox INTEGER NOT NULL + ); +""" + + +insert_reaction = """ + INSERT INTO reactions VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +""" + +get_complex_group_sql = """ + SELECT * FROM complexes WHERE composition_id=? AND group_id=? +""" + + +# TODO: structure these global variables better +DISPATCHER_RANK = 0 + +# message tags + +# sent by workers to the dispatcher once they have finished initializing +# only sent once +INITIALIZATION_FINISHED = 0 + +# sent by workers to the dispatcher to request a new table +SEND_ME_A_WORK_BATCH = 1 + +# sent by dispatcher to workers when delivering a new table +HERE_IS_A_WORK_BATCH = 2 + +# sent by workers to the dispatcher when reaction passes db decision tree +NEW_REACTION_DB = 3 + +# sent by workers to the dispatcher when reaction passes logging decision tree +NEW_REACTION_LOGGING = 4 + +class WorkerState(Enum): + INITIALIZING = 0 + RUNNING = 1 + FINISHED = 2 + + +def log_message(*args, **kwargs): + print( + '[' + strftime('%H:%M:%S', localtime()) + ']', + *args, **kwargs) + +def dispatcher( #input of dispatcher. + mol_entries, #1 + dgl_molecules_dict, + grapher_features, + dispatcher_payload, #2 + #wx + reaction_lmdb_path +): + + comm = MPI.COMM_WORLD + work_batch_list = [] + bucket_con = sqlite3.connect(dispatcher_payload.bucket_db_file) + bucket_cur = bucket_con.cursor() + size_cur = bucket_con.cursor() + + res = bucket_cur.execute("SELECT * FROM group_counts") + for (composition_id, count) in res: + for (i,j) in product(range(count), repeat=2): + work_batch_list.append( + (composition_id, i, j)) + + composition_names = {} + res = bucket_cur.execute("SELECT * FROM compositions") + for (composition_id, composition) in res: + composition_names[composition_id] = composition + + log_message("creating reaction network db") + rn_con = sqlite3.connect(dispatcher_payload.reaction_network_db_file) + rn_cur = rn_con.cursor() + rn_cur.execute(create_metadata_table) + rn_cur.execute(create_reactions_table) + rn_con.commit() + + #### HY + ## initialize preprocess data + +#wx: writting lmdbs in dispatcher ? +#wx: each worker needs to initlize rxn_networks_graph at worker level. + rxn_networks_g = rxn_networks_graph( + mol_entries, + dgl_molecules_dict, + grapher_features, + dispatcher_payload.bondnet_test, + reaction_lmdb_path #wx. 
different + ) + #### + + log_message("initializing report generator") + + # since MPI processes spin lock, we don't want to have the dispathcer + # spend a bunch of time generating molecule pictures + report_generator = ReportGenerator( + mol_entries, + dispatcher_payload.report_file, + rebuild_mol_pictures=False + ) + + worker_states = {} + + worker_ranks = [i for i in range(comm.Get_size()) if i != DISPATCHER_RANK] + + for i in worker_ranks: + worker_states[i] = WorkerState.INITIALIZING + + for i in worker_states: + # block, waiting for workers to initialize + comm.recv(source=i, tag=INITIALIZATION_FINISHED) + worker_states[i] = WorkerState.RUNNING + + log_message("all workers running") + + reaction_index = 0 + + log_message("handling requests") + + batches_left_at_last_checkpoint = len(work_batch_list) + last_checkpoint_time = floor(time()) + while True: + if WorkerState.RUNNING not in worker_states.values(): + break + + current_time = floor(time()) + time_diff = current_time - last_checkpoint_time + if ( current_time % dispatcher_payload.checkpoint_interval == 0 and + time_diff > 0): + batches_left_at_current_checkpoint = len(work_batch_list) + batch_count_diff = ( + batches_left_at_last_checkpoint - + batches_left_at_current_checkpoint) + + batch_consumption_rate = batch_count_diff / time_diff + + log_message("batches remaining:", batches_left_at_current_checkpoint) + log_message("batch consumption rate:", + batch_consumption_rate, + "batches per second") + + + batches_left_at_last_checkpoint = batches_left_at_current_checkpoint + last_checkpoint_time = current_time + + + status = MPI.Status() + data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status) + tag = status.Get_tag() + rank = status.Get_source() + + #this is the last step when worker is out of work. + if tag == SEND_ME_A_WORK_BATCH: + if len(work_batch_list) == 0: + comm.send(None, dest=rank, tag=HERE_IS_A_WORK_BATCH) + worker_states[rank] = WorkerState.FINISHED + else: + # pop removes and returns the last item in the list + work_batch = work_batch_list.pop() + comm.send(work_batch, dest=rank, tag=HERE_IS_A_WORK_BATCH) + composition_id, group_id_0, group_id_1 = work_batch + log_message( + "dispatched", + composition_names[composition_id], + ": group ids:", + group_id_0, group_id_1 + ) + #this is where worker is doing things. found a good reation and send to dispatcher. + #if this is correct, then create_rxn_networks_graph operates on worker instead of dispatcher. + #This is the reason why adding samples one by one. QA: where is batch of reactions? +#ten reactions, first filter out, second send , next eight + elif tag == NEW_REACTION_DB: + reaction = data + rn_cur.execute( + insert_reaction, + (reaction_index, + reaction['number_of_reactants'], + reaction['number_of_products'], + reaction['reactants'][0], + reaction['reactants'][1], + reaction['products'][0], + reaction['products'][1], + reaction['rate'], + reaction['dG'], + reaction['dG_barrier'], + reaction['is_redox'] + )) + + # # Create reaction graph + add to LMDB + rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) #wx in worker level + +#dispatch tracks global index, worker tracks local index in that batch. 
+
+            reaction_index += 1
+            if reaction_index % dispatcher_payload.commit_frequency == 0:
+                rn_con.commit()
+
+
+        elif tag == NEW_REACTION_LOGGING:
+
+            reaction = data[0]
+            decision_path = data[1]
+
+            report_generator.emit_verbatim(decision_path)
+            report_generator.emit_reaction(reaction)
+            report_generator.emit_bond_breakage(reaction)
+            report_generator.emit_newline()
+
+
+
+    log_message("finalizing database and generating report")
+    rn_cur.execute(
+        insert_metadata,
+        (len(mol_entries),
+         reaction_index)
+    )
+
+
+    report_generator.finished()
+    rn_con.commit()
+    bucket_con.close()
+    rn_con.close()
+
+
+def worker(
+    mol_entries, #inputs of the worker
+    worker_payload,
+    dgl_molecules_dict,
+    grapher_features,
+    reaction_lmdb_path
+):
+
+#wx
+    local_reaction_idx = 0
+
+    comm = MPI.COMM_WORLD
+    con = sqlite3.connect(worker_payload.bucket_db_file)
+    cur = con.cursor()
+
+    comm.send(None, dest=DISPATCHER_RANK, tag=INITIALIZATION_FINISHED)
+
+#wx
+    rank = comm.Get_rank() #rank id of this worker
+
+    lmdb_name_i = reaction_lmdb_path.split(".lmdb")[0] + "_" + str(rank) + ".lmdb"
+
+    rxn_networks_g = rxn_networks_graph(
+        mol_entries,
+        dgl_molecules_dict,
+        grapher_features,
+        #dispatcher_payload.bondnet_test,
+        lmdb_name_i #wx: one lmdb shard per rank
+    )
+
+
+    while True:
+        comm.send(None, dest=DISPATCHER_RANK, tag=SEND_ME_A_WORK_BATCH)
+        work_batch = comm.recv(source=DISPATCHER_RANK, tag=HERE_IS_A_WORK_BATCH)
+
+        if work_batch is None:
+            break
+
+
+        composition_id, group_id_0, group_id_1 = work_batch
+
+
+        if group_id_0 == group_id_1:
+
+            res = cur.execute(
+                get_complex_group_sql,
+                (composition_id, group_id_0))
+
+            bucket = []
+            for row in res:
+                bucket.append((row[0],row[1]))
+
+            iterator = permutations(bucket, r=2)
+
+        else:
+
+            res_0 = cur.execute(
+                get_complex_group_sql,
+                (composition_id, group_id_0))
+
+            bucket_0 = []
+            for row in res_0:
+                bucket_0.append((row[0],row[1]))
+
+            res_1 = cur.execute(
+                get_complex_group_sql,
+                (composition_id, group_id_1))
+
+            bucket_1 = []
+            for row in res_1:
+                bucket_1.append((row[0],row[1]))
+
+            iterator = product(bucket_0, bucket_1)
+
+
+
+        for (reactants, products) in iterator:
+            reaction = {
+                'reactants' : reactants,
+                'products' : products,
+                'number_of_reactants' : len([i for i in reactants if i != -1]),
+                'number_of_products' : len([i for i in products if i != -1])}
+
+
+            decision_pathway = []
+            if run_decision_tree(reaction,
+                                 mol_entries,
+                                 worker_payload.params,
+                                 worker_payload.reaction_decision_tree,
+                                 decision_pathway
+                                 ):
+
+                comm.send(
+                    reaction,
+                    dest=DISPATCHER_RANK,
+                    tag=NEW_REACTION_DB)
+
+                #comm.send above ships the reaction to the dispatcher;
+                #the graph itself goes into this worker's lmdb shard.
+                rxn_networks_g.create_rxn_networks_graph(reaction, local_reaction_idx)
+                local_reaction_idx += 1
+
+
+            if run_decision_tree(reaction,
+                                 mol_entries,
+                                 worker_payload.params,
+                                 worker_payload.logging_decision_tree):
+
+                comm.send(
+                    (reaction,
+                     '\n'.join([str(f) for f in decision_pathway])
+                    ),
+
+                    dest=DISPATCHER_RANK,
+                    tag=NEW_REACTION_LOGGING)
diff --git a/HiPRGen/rxn_networks_graph.py b/HiPRGen/rxn_networks_graph.py
index eb7721c..c6d1949 100644
--- a/HiPRGen/rxn_networks_graph.py
+++ b/HiPRGen/rxn_networks_graph.py
@@ -19,14 +19,14 @@ def __init__(
         mol_entries,
         dgl_molecules_dict,
         grapher_features,
-        report_file_path,
+        #report_file_path,
         reaction_lmdb_path #wx
     ):
#wx: which of these should come from the molecule lmdbs?
         self.mol_entries = mol_entries
         self.dgl_mol_dict = dgl_molecules_dict
         self.grapher_features = grapher_features
-        self.report_file_path = report_file_path
+        #self.report_file_path = report_file_path
         self.reaction_lmdb_path = reaction_lmdb_path
 
diff --git a/run_network_generation.py b/run_network_generation.py
index d42e9af..d076837 100644
--- a/run_network_generation.py
+++ b/run_network_generation.py
@@ -36,16 +36,22 @@
 if rank == DISPATCHER_RANK:
     dispatcher_payload = loadfn(dispatcher_payload_json)
 
-    dispatcher(mol_entries,
-               dgl_molecules_dict_pickle_file,
-               grapher_features_dict_pickle_file,
-               dispatcher_payload,
-               #wx,
-               reaction_lmdb_path
-               )
-
+    dispatcher(mol_entries,
+               dispatcher_payload)
+
+    # graph and lmdb arguments moved to the worker level
+    # dispatcher(mol_entries,
+    #            dgl_molecules_dict_pickle_file,
+    #            grapher_features_dict_pickle_file,
+    #            dispatcher_payload,
+    #            #wx,
+    #            reaction_lmdb_path
+    #            )
 else:
     worker_payload = loadfn(worker_payload_json)
 
     worker(mol_entries,
-           worker_payload
+           worker_payload,
+           dgl_molecules_dict_pickle_file,
+           grapher_features_dict_pickle_file,
+           reaction_lmdb_path
            )
diff --git a/test.py b/test.py
index 8eee93b..223039b 100644
--- a/test.py
+++ b/test.py
@@ -878,6 +878,11 @@ def euvl_phase2_test():
 
     return tests_passed
 
+
+
+
+
+
 def euvl_bondnet_test():
     start_time = time.time()
 
@@ -890,7 +895,11 @@ def euvl_bondnet_test():
 
     ## HY
     bondnet_test_json = "./scratch/euvl_phase2_test/reaction_networks_graphs"
+    lmdbs_path_mol = "./scratch/euvl_phase2_test/lmdbs/mol"
+    lmdbs_path_reaction = "./scratch/euvl_phase2_test/lmdbs/reaction"
     subprocess.run(["mkdir", bondnet_test_json])
+    subprocess.run(["mkdir", "-p", lmdbs_path_mol])
+    subprocess.run(["mkdir", "-p", lmdbs_path_reaction])
     ##
 
     species_decision_tree = euvl_species_decision_tree
@@ -900,8 +909,13 @@ def euvl_bondnet_test():
         "electron_free_energy": 0.0,
     }
 
+    # with open(folder + "/mol_entries.pickle", 'rb') as f:
+    #     mol_entries = pickle.load(f)
+
     # import pdb
     # pdb.set_trace()
+    # with open("/global/home/users/wenbinxu/data/rep/rep/HiPRGen/test/euvl_phase2_test/mol_entries.pickle", 'rb') as f:
+    #     mol_entries = pickle.load(f)
 
     mol_entries, dgl_molecules_dict = species_filter(
         #wx: dump mol lmdb at the end of species filter.
@@ -914,7 +928,7 @@ def euvl_bondnet_test():
         coordimer_weight=lambda mol: (mol.get_free_energy(params["temperature"])),
         species_logging_decision_tree=species_decision_tree,
         generate_unfiltered_mol_pictures=False,
-        mol_lmdb_path = folder + "/mol.lmdb",
+        mol_lmdb_path = folder + "/lmdbs/mol/mol.lmdb",
     )
 
 
@@ -951,14 +965,18 @@ def euvl_bondnet_test():
             number_of_threads,
             "python",
             "run_network_generation.py",
+            # "/global/home/users/wenbinxu/data/rep/rep/HiPRGen/test/euvl_phase2_test" + "/mol_entries.pickle",
             folder + "/mol_entries.pickle",
             folder + "/dispatcher_payload.json",
             folder + "/worker_payload.json",
+
+            # "/global/home/users/wenbinxu/data/rep/rep/HiPRGen/test/euvl_phase2_test" + "/dgl_mol_graphs.pickle",
+            # "/global/home/users/wenbinxu/data/rep/rep/HiPRGen/test/euvl_phase2_test" + "/grapher_features.pickle",
             folder + "/dgl_mol_graphs.pickle",
             folder + "/grapher_features.pickle",
             #wx, path to write reaction lmdb
-            folder + "/reaction.lmdb"
+            folder + "/lmdbs/reaction/reaction.lmdb"
         ]
     )
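
The pooled mean/std helper patched into HiPRGen/lmdb_dataset.py above can be sanity-checked against a single pass over the concatenated data. A minimal sketch, assuming the helper is exposed as combine_mean_std(mean_list, std_list, count_list); the def line sits outside the hunk, so that name is an assumption:

import numpy as np

from HiPRGen.lmdb_dataset import combine_mean_std  # assumed export name

rng = np.random.default_rng(0)
shards = [rng.normal(loc=1.0, scale=2.0, size=n) for n in (100, 50, 75)]

means = [s.mean() for s in shards]
stds = [s.std(ddof=1) for s in shards]  # per-shard sample std
counts = [len(s) for s in shards]

mean, std = combine_mean_std(means, stds, counts)

# With the N - 1 divisor, the combined statistics agree with a
# direct computation over all of the data at once.
all_data = np.concatenate(shards)
assert np.isclose(mean, all_data.mean())
assert np.isclose(std, all_data.std(ddof=1))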
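With this patch each MPI worker writes its own shard (reaction_1.lmdb, reaction_2.lmdb, ...) derived from the reaction_lmdb_path passed on the command line, so downstream consumers either read every shard or merge them first. A minimal merge sketch, assuming samples are stored under ascii-encoded integer keys with a pickled "length" entry as in write_to_lmdb above; merge_lmdb_shards and its lmdb.open flags are assumptions, not part of this patch:

import glob
import pickle

import lmdb


def merge_lmdb_shards(shard_glob, merged_path, map_size=1099511627776):
    # Output environment, opened with the same style of flags as write_to_lmdb.
    out_env = lmdb.open(
        merged_path, map_size=map_size, subdir=False,
        meminit=False, map_async=True,
    )
    idx = 0
    for shard_path in sorted(glob.glob(shard_glob)):
        env = lmdb.open(shard_path, subdir=False, readonly=True, lock=False)
        with env.begin() as txn:
            length = pickle.loads(txn.get("length".encode("ascii")))
            with out_env.begin(write=True) as out_txn:
                for i in range(length):
                    # Re-key each sample so the merged indices stay contiguous.
                    out_txn.put(
                        f"{idx}".encode("ascii"),
                        txn.get(f"{i}".encode("ascii")),
                    )
                    idx += 1
        env.close()
    with out_env.begin(write=True) as out_txn:
        out_txn.put("length".encode("ascii"), pickle.dumps(idx, protocol=-1))
    out_env.sync()
    out_env.close()


# e.g. merge_lmdb_shards("./scratch/euvl_phase2_test/lmdbs/reaction/reaction_*.lmdb",
#                        "./scratch/euvl_phase2_test/lmdbs/reaction/merged.lmdb")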
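The index comments above leave one mismatch: the dispatcher's reaction_index is global across all accepted reactions, while each worker's local_reaction_idx only counts the reactions in its own shard. If globally unique ids were ever needed without routing every graph through the dispatcher, one option (not implemented in this patch) is an exclusive prefix scan over the per-worker counts once the filtering loop ends; the dispatcher rank would simply contribute zero:

from mpi4py import MPI

comm = MPI.COMM_WORLD

# Stand-in for the worker's final local_reaction_idx
# (the dispatcher rank would call exscan with local_count = 0).
local_count = 42

# exscan hands each rank the sum of the counts on all lower ranks,
# i.e. the first global id this rank may claim; rank 0 receives None.
offset = comm.exscan(local_count)
if offset is None:
    offset = 0

global_ids = range(offset, offset + local_count)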