Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update: frontend docs #53

Merged
merged 2 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import linecache
import itertools
from collections import defaultdict
from typing import Dict
import xml.etree.ElementTree as ET
from copy import deepcopy
import warnings
Expand Down Expand Up @@ -1825,26 +1826,46 @@ def getParameters(self):


class HarmonicBondJaxGenerator:
def __init__(self, ff):
def __init__(self, ff:Hamiltonian):
self.name = "HarmonicBondForce"
self.ff = ff
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self.ff:Hamiltonian = ff
self.fftree:ForcefieldTree = ff.fftree
self.paramtree:Dict = ff.paramtree

def extract(self):
"""
extract forcefield paramters from ForcefieldTree.
"""
lengths = self.fftree.get_attribs(f"{self.name}/Bond", "length")
# get_attribs will return a list of list.
ks = self.fftree.get_attribs(f"{self.name}/Bond", "k")
self.paramtree[self.name] = {}
self.paramtree[self.name]["length"] = jnp.array(lengths)
self.paramtree[self.name]["k"] = jnp.array(ks)

def overwrite(self):
"""
update parameters in the fftree by using paramtree of this generator.
"""
self.fftree.set_attrib(f"{self.name}/Bond", "length",
self.paramtree[self.name]["length"])
self.fftree.set_attrib(f"{self.name}/Bond", "k",
self.paramtree[self.name]["k"])

def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
"""
This method will create a potential calculation kernel. It usually should do the following:

1. Match the corresponding bond parameters according to the atomic types at both ends of each bond.

2. Create a potential calculation kernel, and pass those mapped parameters to the kernel.

3. assign the jax potential to the _jaxPotential.

Args:
Those args are the same as those in createSystem.
"""

# initialize typemap
matcher = TypeMatcher(self.fftree, "HarmonicBondForce/Bond")

Expand Down
Loading