Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BonDNet #13

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 54 additions & 9 deletions HiPRGen/species_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from pymatgen.core.sites import Site
from pymatgen.core.structure import Molecule
from pymatgen.analysis.graphs import MoleculeGraph

from bondnet.model.training_utils import get_grapher
from bondnet.core.molwrapper import MoleculeWrapper
from bondnet.data.transformers import HeteroGraphFeatureStandardScaler

from bondnet.core.molwrapper import create_wrapper_mol_from_atoms_and_bonds
from bondnet.utils import int_atom

"""
Phase 1: species filtering
Expand Down Expand Up @@ -221,34 +223,77 @@ def collapse_isomorphism_group(g):

log_message(str(len(fragment_dict.keys())) + " unique fragments found")


# Make DGL Molecule graphs via BonDNet functions
log_message("creating dgl molecule graphs")
dgl_molecules_dict = {}
dgl_molecules = []
extra_keys = []


# BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS
for mol in mol_entries:
# print(f"mol: {mol.mol_graph}")
molecule_grapher = get_grapher(extra_keys)

non_metal_bonds = [ (i, j) for i, j, _ in mol.covalent_graph.edges.data()]
molecule_grapher = get_grapher(
features = extra_keys,
allowed_charges=[-2,-1,0,1,2],
global_feats=["charge"],
) # same
non_metal_bonds = [ (i, j) for i, j, _ in mol.covalent_graph.edges.data()] # same

# print(f"non metal bonds: {non_metal_bonds}")
mol_wrapper = MoleculeWrapper(mol_graph = mol.mol_graph, free_energy = None, id = mol.entry_id, non_metal_bonds = non_metal_bonds)
# use create molecule wrapper instead here
#mol_wrapper = MoleculeWrapper(
# mol_graph = mol.mol_graph,
# free_energy = None, id = mol.entry_id,
# non_metal_bonds = non_metal_bonds,
# extra_keys = extra_keys
#)
#print(mol.mol_graph.molecule.sites)
#print(mol.mol_graph.graph)
#print(mol.mol_graph.graph.edges())

species = [i.specie for i in mol.mol_graph.molecule.sites]
coords = [i.coords for i in mol.mol_graph.molecule.sites]

bonds = [
[i[0], i[1]] for i in mol.mol_graph.graph.edges()
]

mol_wrapper = create_wrapper_mol_from_atoms_and_bonds(
species,
coords,
bonds,
charge = mol.charge,
functional_group=None,
identifier=mol.entry_id,
original_atom_ind=None,
original_bond_ind=None,
atom_features=None,
bond_features=None,
global_features={"charge": mol.charge}
)
mol_wrapper.nonmetal_bonds = non_metal_bonds
feature = {'charge': mol.charge}
dgl_molecule_graph = molecule_grapher.build_graph_and_featurize(mol_wrapper, extra_feats_info = feature, dataset_species = elements)
dgl_molecule_graph = molecule_grapher.build_graph_and_featurize(
mol_wrapper,
extra_feats_info = feature,
element_set = elements
)
dgl_molecules.append(dgl_molecule_graph)
for nt in ["global", "atom", "bond"]:
print(f"nt: {nt}")
fts = dgl_molecule_graph.nodes[nt].data["feat"]
print(f"features: {fts}")
dgl_molecules_dict[mol.entry_id] = mol.ind
print(molecule_grapher.feature_name)
grapher_features= {'feature_size':molecule_grapher.feature_size, 'feature_name': molecule_grapher.feature_name}
#mol_wrapper_dict[mol.entry_id] = mol_wrapper

# Normalize DGL molecule graphs
scaler = HeteroGraphFeatureStandardScaler(mean = None, std = None)
normalized_graphs = scaler(dgl_molecules)

# BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS

# print(f"mean: {scaler._mean}")
# print(f"std: {scaler._std}")

Expand Down Expand Up @@ -307,4 +352,4 @@ def add_electron_species(
mol_entries.append(electron_entry)
with open(mol_entries_pickle_location, "wb") as f:
pickle.dump(mol_entries, f)
return mol_entries
return mol_entries
Loading