diff --git a/deepchem/utils/dft_utils/__init__.py b/deepchem/utils/dft_utils/__init__.py index 70912f975b..8212fb7236 100644 --- a/deepchem/utils/dft_utils/__init__.py +++ b/deepchem/utils/dft_utils/__init__.py @@ -11,7 +11,11 @@ try: from deepchem.utils.dft_utils.hamilton.intor.lattice import Lattice - from deepchem.utils.dft_utils.datastruct import ZType + from deepchem.utils.dft_utils.data.datastruct import ZType + from deepchem.utils.dft_utils.data.datastruct import AtomPosType + from deepchem.utils.dft_utils.data.datastruct import AtomZsType + from deepchem.utils.dft_utils.data.datastruct import SpinParam + from deepchem.utils.dft_utils.data.datastruct import ValGrad from deepchem.utils.dft_utils.hamilton.orbparams import BaseOrbParams from deepchem.utils.dft_utils.hamilton.orbparams import QROrbParams diff --git a/deepchem/utils/dft_utils/api/parser.py b/deepchem/utils/dft_utils/api/parser.py index 706badb994..5ac949822f 100644 --- a/deepchem/utils/dft_utils/api/parser.py +++ b/deepchem/utils/dft_utils/api/parser.py @@ -1,6 +1,6 @@ from typing import Union, Tuple import torch -from deepchem.utils.dft_utils.datastruct import AtomZsType, AtomPosType +from deepchem.utils.dft_utils import AtomZsType, AtomPosType from deepchem.utils import get_atomz __all__ = ["parse_moldesc"] diff --git a/deepchem/utils/dft_utils/data/datastruct.py b/deepchem/utils/dft_utils/data/datastruct.py new file mode 100644 index 0000000000..395495c4e0 --- /dev/null +++ b/deepchem/utils/dft_utils/data/datastruct.py @@ -0,0 +1,135 @@ +""" +Density Functional Theory Data Structure Utilities +""" +from typing import Union, TypeVar, Generic, Optional, Callable, List +from dataclasses import dataclass +import torch +import numpy as np + +__all__ = ["ZType"] + +T = TypeVar('T') +P = TypeVar('P') + +# type of the atom Z +ZType = Union[int, float, torch.Tensor] + +# input types +AtomZsType = Union[List[str], List[ZType], torch.Tensor] +AtomPosType = Union[List[List[float]], np.ndarray, torch.Tensor] + + +@dataclass +class SpinParam(Generic[T]): + """Data structure to store different values for spin-up and spin-down electrons. + Examples + -------- + >>> import torch + >>> from deepchem.utils.dft_utils import SpinParam + >>> dens_u = torch.ones(1) + >>> dens_d = torch.zeros(1) + >>> sp = SpinParam(u=dens_u, d=dens_d) + >>> sp.u + tensor([1.]) + >>> sp.sum() + tensor([1.]) + >>> sp.reduce(torch.multiply) + tensor([0.]) + """ + + def __init__(self, u: T, d: T): + """Initialize the SpinParam object. + Parameters + ---------- + u: any type + The parameters that corresponds to the spin-up electrons. + d: any type + The parameters that corresponds to the spin-down electrons. + """ + self.u = u + self.d = d + + def __repr__(self) -> str: + """Return the string representation of the SpinParam object.""" + return f"SpinParam(u={self.u}, d={self.d})" + + def sum(self): + """Returns the sum of up and down parameters.""" + + return self.u + self.d + + def reduce(self, fcn: Callable) -> T: + """Reduce up and down parameters with the given function.""" + + return fcn(self.u, self.d) + + +@dataclass +class ValGrad: + """Data structure that contains local information about density profiles. + Data structure used as a umbrella class for density profiles and the + derivative of the potential w.r.t. density profiles. + Examples + -------- + >>> import torch + >>> from deepchem.utils.dft_utils import ValGrad + >>> dens = torch.ones(1) + >>> grad = torch.zeros(1) + >>> lapl = torch.ones(1) + >>> kin = torch.ones(1) + >>> vg = ValGrad(value=dens, grad=grad, lapl=lapl, kin=kin) + >>> vg + vg + ValGrad(value=tensor([2.]), grad=tensor([0.]), lapl=tensor([2.]), kin=tensor([2.])) + >>> vg * 5 + ValGrad(value=tensor([5.]), grad=tensor([0.]), lapl=tensor([5.]), kin=tensor([5.])) + """ + + def __init__(self, + value: torch.Tensor, + grad: Optional[torch.Tensor] = None, + lapl: Optional[torch.Tensor] = None, + kin: Optional[torch.Tensor] = None): + """Initialize the ValGrad object. + Parameters + ---------- + value: torch.Tensor + Tensors containing the value of the local information. + grad: torch.Tensor or None + If tensor, it represents the gradient of the local information with + shape ``(..., 3)`` where ``...`` should be the same shape as ``value``. + lapl: torch.Tensor or None + If tensor, represents the laplacian value of the local information. + It should have the same shape as ``value``. + kin: torch.Tensor or None + If tensor, represents the local kinetic energy density. + It should have the same shape as ``value``. + """ + self.value = value + self.grad = grad + self.lapl = lapl + self.kin = kin + + def __add__(self, b): + """Add two ValGrad objects together.""" + return ValGrad( + value=self.value + b.value, + grad=self.grad + b.grad if self.grad is not None else None, + lapl=self.lapl + b.lapl if self.lapl is not None else None, + kin=self.kin + b.kin if self.kin is not None else None, + ) + + def __mul__(self, f: Union[float, int, torch.Tensor]): + """Multiply the ValGrad object with a scalar.""" + if isinstance(f, torch.Tensor): + assert f.numel( + ) == 1, "ValGrad multiplication with tensor can only be done with 1-element tensor" + + return ValGrad( + value=self.value * f, + grad=self.grad * f if self.grad is not None else None, + lapl=self.lapl * f if self.lapl is not None else None, + kin=self.kin * f if self.kin is not None else None, + ) + + def __repr__(self): + return f"ValGrad(value={self.value}, grad={self.grad}, lapl={self.lapl}, kin={self.kin})" diff --git a/deepchem/utils/dft_utils/datastruct.py b/deepchem/utils/dft_utils/datastruct.py deleted file mode 100644 index adaee17a9b..0000000000 --- a/deepchem/utils/dft_utils/datastruct.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Density Functional Theory Data Structure Utilities -""" -from typing import Union, TypeVar, List -import numpy as np -import torch - -__all__ = ["ZType"] - -T = TypeVar('T') -P = TypeVar('P') - -# type of the atom Z -ZType = Union[int, float, torch.Tensor] -# input types -AtomZsType = Union[List[str], List[ZType], torch.Tensor] -AtomPosType = Union[List[List[float]], np.ndarray, torch.Tensor] diff --git a/deepchem/utils/test/test_dft_utils.py b/deepchem/utils/test/test_dft_utils.py index 1c911ea003..c9f9d7862f 100644 --- a/deepchem/utils/test/test_dft_utils.py +++ b/deepchem/utils/test/test_dft_utils.py @@ -105,3 +105,45 @@ def test_parse_moldesc(): torch.tensor( [[0.86625, 0.00000, 0.00000], [-0.86625, 0.00000, 0.00000]], dtype=torch.float64)) + + +@pytest.mark.torch +def test_spin_param(): + """Test SpinParam object.""" + from deepchem.utils.dft_utils import SpinParam + dens_u = torch.ones(1) + dens_d = torch.zeros(1) + sp = SpinParam(u=dens_u, d=dens_d) + + assert torch.allclose(sp.u, dens_u) + assert torch.allclose(sp.d, dens_d) + assert torch.allclose(sp.sum(), torch.tensor([1.])) + assert torch.allclose(sp.reduce(torch.multiply), torch.tensor([0.])) + + +@pytest.mark.torch +def test_val_grad(): + """Test ValGrad object.""" + from deepchem.utils.dft_utils import ValGrad + dens = torch.ones(1) + grad = torch.zeros(1) + lapl = torch.ones(1) + kin = torch.ones(1) + vg = ValGrad(value=dens, grad=grad, lapl=lapl, kin=kin) + + assert torch.allclose(vg.value, dens) + assert torch.allclose(vg.grad, grad) + assert torch.allclose(vg.lapl, lapl) + assert torch.allclose(vg.kin, kin) + + vg2 = vg + vg + assert torch.allclose(vg2.value, torch.tensor([2.])) + assert torch.allclose(vg2.grad, torch.tensor([0.])) + assert torch.allclose(vg2.lapl, torch.tensor([2.])) + assert torch.allclose(vg2.kin, torch.tensor([2.])) + + vg5 = vg * 5 + assert torch.allclose(vg5.value, torch.tensor([5.])) + assert torch.allclose(vg5.grad, torch.tensor([0.])) + assert torch.allclose(vg5.lapl, torch.tensor([5.])) + assert torch.allclose(vg5.kin, torch.tensor([5.])) diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst index 80200b9053..9088210363 100644 --- a/docs/source/api_reference/utils.rst +++ b/docs/source/api_reference/utils.rst @@ -285,6 +285,12 @@ The utilites here are used to create an object that contains information about a .. autoclass:: deepchem.utils.dft_utils.Lattice :members: +.. autoclass:: deepchem.utils.dft_utils.SpinParam + :members: + +.. autoclass:: deepchem.utils.dft_utils.ValGrad + :members: + .. autoclass:: deepchem.utils.dftutils.KSCalc :members: