From 024a94c49725c0e92bb396d24681bbcf7c60c51f Mon Sep 17 00:00:00 2001 From: GreatRSingh Date: Fri, 19 Jan 2024 02:18:31 +0530 Subject: [PATCH 1/5] radialgrid - Partial Tests --- deepchem/utils/dft_utils/__init__.py | 9 + deepchem/utils/dft_utils/grid/radial_grid.py | 560 +++++++++++++++++++ deepchem/utils/test/test_dft_utils.py | 30 + docs/source/api_reference/utils.rst | 24 + 4 files changed, 623 insertions(+) create mode 100644 deepchem/utils/dft_utils/grid/radial_grid.py diff --git a/deepchem/utils/dft_utils/__init__.py b/deepchem/utils/dft_utils/__init__.py index 48c8f984af..81df69c7e1 100644 --- a/deepchem/utils/dft_utils/__init__.py +++ b/deepchem/utils/dft_utils/__init__.py @@ -28,6 +28,15 @@ from deepchem.utils.dft_utils.grid.base_grid import BaseGrid + from deepchem.utils.dft_utils.grid.radial_grid import RadialGrid + from deepchem.utils.dft_utils.grid.radial_grid import get_xw_integration + from deepchem.utils.dft_utils.grid.radial_grid import SlicedRadialGrid + from deepchem.utils.dft_utils.grid.radial_grid import BaseGridTransform + from deepchem.utils.dft_utils.grid.radial_grid import DE2Transformation + from deepchem.utils.dft_utils.grid.radial_grid import LogM3Transformation + from deepchem.utils.dft_utils.grid.radial_grid import TreutlerM4Transformation + from deepchem.utils.dft_utils.grid.radial_grid import get_grid_transform + from deepchem.utils.dft_utils.xc.base_xc import BaseXC from deepchem.utils.dft_utils.xc.base_xc import AddBaseXC from deepchem.utils.dft_utils.xc.base_xc import MulBaseXC diff --git a/deepchem/utils/dft_utils/grid/radial_grid.py b/deepchem/utils/dft_utils/grid/radial_grid.py new file mode 100644 index 0000000000..de3b8f12db --- /dev/null +++ b/deepchem/utils/dft_utils/grid/radial_grid.py @@ -0,0 +1,560 @@ +from __future__ import annotations +from abc import abstractmethod +import torch +import numpy as np +from typing import Union, Tuple +from deepchem.utils.dft_utils import BaseGrid + + +class RadialGrid(BaseGrid): + """ + Grid for radially symmetric system. This grid consists grid_integrator + and grid_transform specifiers. + + grid_integrator is to specify how to perform an integration on a fixed + interval from -1 to 1. + + grid_transform is to transform the integration from the coordinate of + grid_integrator to the actual coordinate. + + Examples + -------- + >>> grid = RadialGrid(100, grid_integrator="chebyshev", + ... grid_transform="logm3") + >>> grid.get_rgrid().shape + torch.Size([100, 1]) + >>> grid.get_dvolume().shape + torch.Size([100]) + + """ + + def __init__(self, + ngrid: int, + grid_integrator: str = "chebyshev", + grid_transform: Union[str, BaseGridTransform] = "logm3", + dtype: torch.dtype = torch.float64, + device: torch.device = torch.device('cpu')): + """Initialize the RadialGrid. + + Parameters + ---------- + ngrid: int + Number of grid points. + grid_integrator: str (default "chebyshev") + The grid integrator to use. Available options are "chebyshev", + "chebyshev2", and "uniform". + grid_transform: Union[str, BaseGridTransform] (default "logm3") + The grid transformation to use. Available options are "logm3", + "de2", and "treutlerm4". + dtype: torch.dtype, optional (default torch.float64) + The data type to use for the grid. + device: torch.device, optional (default torch.device('cpu')) + The device to use for the grid. 
+ + """ + self._dtype = dtype + self._device = device + grid_transform_obj = get_grid_transform(grid_transform) + + # get the location and weights of the integration in its original + # coordinate + _x, _w = get_xw_integration(ngrid, grid_integrator) + x = torch.as_tensor(_x, dtype=dtype, device=device) + w = torch.as_tensor(_w, dtype=dtype, device=device) + r = grid_transform_obj.x2r(x) # (ngrid,) + + # get the coordinate in Cartesian + r1 = r.unsqueeze(-1) # (ngrid, 1) + self.rgrid = r1 + + # integration element + drdx = grid_transform_obj.get_drdx(x) + vol_elmt = 4 * np.pi * r * r # (ngrid,) + dr = drdx * w + self.dvolume = vol_elmt * dr # (ngrid,) + + @property + def coord_type(self): + """Returns the coordinate type of the grid. + + Returns + ------- + str + The coordinate type of the grid. For RadialGrid, this is "radial". + + """ + return "radial" + + @property + def dtype(self): + """Returns the data type of the grid. + + Returns + ------- + torch.dtype + The data type of the grid. + + """ + return self._dtype + + @property + def device(self): + """Returns the device of the grid. + + Returns + ------- + torch.device + The device of the grid. + + """ + return self._device + + def get_dvolume(self) -> torch.Tensor: + """Returns the integration element of the grid. + + Returns + ------- + torch.Tensor + The integration element of the grid. + + """ + return self.dvolume + + def get_rgrid(self) -> torch.Tensor: + """Returns the grid points. + + Returns + ------- + torch.Tensor + The grid points. + + """ + return self.rgrid + + def __getitem__(self, key: Union[int, slice]) -> RadialGrid: + """Returns a sliced RadialGrid. + + Parameters + ---------- + key: Union[int, slice] + The index or slice to use for slicing the grid. + + Returns + ------- + RadialGrid + The sliced RadialGrid. + + """ + if isinstance(key, slice): + return SlicedRadialGrid(self, key) + else: + raise KeyError("Indexing for RadialGrid is not defined") + + def getparamnames(self, methodname: str, prefix: str = ""): + """Returns the parameter names for the given method. + + Parameters + ---------- + methodname: str + The name of the method. + prefix: str, optional (default "") + The prefix to use for the parameter names. + + Returns + ------- + List[str] + The parameter names for the given method. + + """ + if methodname == "get_dvolume": + return [prefix + "dvolume"] + elif methodname == "get_rgrid": + return [prefix + "rgrid"] + else: + raise KeyError("getparamnames for %s is not set" % methodname) + + +def get_xw_integration(n: int, s0: str) -> Tuple[np.ndarray, np.ndarray]: + """returns ``n`` points of integration from -1 to 1 and its integration + weights + + Examples + -------- + >>> x, w = get_xw_integration(100, "chebyshev") + >>> x.shape + (100,) + >>> w.shape + (100,) + + Parameters + ---------- + n: int + Number of grid points. + s0: str + The grid integrator to use. Available options are `chebyshev`, + `chebyshev2`, and `uniform`. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + The integration points and weights. + + References + ---------- + .. [1] chebyshev polynomial eq (9) & (10) https://doi.org/10.1063/1.475719 + .. [2] Handbook of Mathematical Functions (Abramowitz & Stegun) p. 889 + + """ + + s = s0.lower() + if s == "chebyshev": + np1 = n + 1. + icount = np.arange(n, 0, -1) + ipn1 = icount * np.pi / np1 + sin_ipn1 = np.sin(ipn1) + sin_ipn1_2 = sin_ipn1 * sin_ipn1 + xcheb = (np1 - 2 * icount) / np1 + 2 / np.pi * \ + (1 + 2. / 3 * sin_ipn1 * sin_ipn1) * np.cos(ipn1) * sin_ipn1 + wcheb = 16. 
/ (3 * np1) * sin_ipn1_2 * sin_ipn1_2 + return xcheb, wcheb + + elif s == "chebyshev2": + np1 = n + 1.0 + icount = np.arange(n, 0, -1) + ipn1 = icount * np.pi / np1 + sin_ipn1 = np.sin(ipn1) + xcheb = np.cos(ipn1) + wcheb = np.pi / np1 * sin_ipn1 + return xcheb, wcheb + + elif s == "uniform": + x = np.linspace(-1, 1, n) + w = np.ones(n) * (x[1] - x[0]) + w[0] *= 0.5 + w[-1] *= 0.5 + return x, w + else: + avail = ["chebyshev", "chebyshev2", "uniform"] + raise RuntimeError("Unknown grid_integrator: %s. Available: %s" % + (s0, avail)) + + +class SlicedRadialGrid(RadialGrid): + """Internal class to represent the sliced radial grid""" + + def __init__(self, obj: RadialGrid, key: slice): + """Initialize the SlicedRadialGrid. + + Parameters + ---------- + obj: RadialGrid + The original RadialGrid. + key: slice + The slice to use for slicing the grid. + + """ + self._dtype = obj.dtype + self._device = obj.device + self.dvolume = obj.dvolume[key] + self.rgrid = obj.rgrid[key] + + +# Grid Transformations + + +class BaseGridTransform(object): + """Base class for grid transformation + Grid transformation is to transform the integration from the coordinate of + grid_integrator to the actual coordinate. + + It is used as a base class for other grid transformations. + x2r and get_drdx are abstract methods that need to be implemented. + + """ + + @abstractmethod + def x2r(self, x: torch.Tensor) -> torch.Tensor: + """Transform from x to r coordinate + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + r: torch.Tensor + The coordinate from 0 to inf. + + """ + pass + + @abstractmethod + def get_drdx(self, x: torch.Tensor) -> torch.Tensor: + """Returns the dr/dx + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + drdx: torch.Tensor + The dr/dx. + + """ + pass + + +class DE2Transformation(BaseGridTransform): + """Double exponential formula grid transformation + + Examples + -------- + >>> x = torch.linspace(-1, 1, 100) + >>> r = DE2Transformation().x2r(x) + >>> r.shape + torch.Size([100]) + >>> drdx = DE2Transformation().get_drdx(x) + >>> drdx.shape + torch.Size([100]) + + References + ---------- + .. [1] eq (31) in https://link.springer.com/article/10.1007/s00214-011-0985-x + + """ + + def __init__(self, + alpha: float = 1.0, + rmin: float = 1e-7, + rmax: float = 20): + assert rmin < 1.0 + self.alpha = alpha + self.xmin = -np.log(-np.log(rmin)) # approximate for small r + self.xmax = np.log(rmax) / alpha # approximate for large r + + def x2r(self, x: torch.Tensor) -> torch.Tensor: + """Transform from x to r coordinate + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + r: torch.Tensor + The coordinate from 0 to inf. + + """ + # xnew is from [xmin, xmax] + xnew = 0.5 * (x * (self.xmax - self.xmin) + (self.xmax + self.xmin)) + # r is approximately from [rmin, rmax] + r = torch.exp(self.alpha * xnew - torch.exp(-xnew)) + return r + + def get_drdx(self, x: torch.Tensor) -> torch.Tensor: + """Returns the dr/dx + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + drdx: torch.Tensor + The dr/dx. 
+ + """ + r = self.x2r(x) + xnew = 0.5 * (x * (self.xmax - self.xmin) + (self.xmax + self.xmin)) + return r * (self.alpha + torch.exp(-xnew)) * (0.5 * + (self.xmax - self.xmin)) + + +class LogM3Transformation(BaseGridTransform): + """LogM3 grid transformation + + Examples + -------- + >>> x = torch.linspace(-1, 1, 100) + >>> r = LogM3Transformation().x2r(x) + >>> r.shape + torch.Size([100]) + >>> drdx = LogM3Transformation().get_drdx(x) + >>> drdx.shape + torch.Size([100]) + + References + ---------- + .. [1] eq (12) in https://aip.scitation.org/doi/pdf/10.1063/1.475719 + + """ + + def __init__(self, ra: float = 1.0, eps: float = 1e-15): + """Initialize the LogM3Transformation. + + Parameters + ---------- + ra: float (default 1.0) + The parameter to control the range of the grid. + eps: float (default 1e-15) + The parameter to avoid numerical instability. + + """ + self.ra = ra + self.eps = eps + self.ln2 = np.log(2.0 + eps) + + def x2r(self, x: torch.Tensor) -> torch.Tensor: + """Transform from x to r coordinate + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + torch.Tensor + The coordinate from 0 to inf. + + """ + return self.ra * (1 - torch.log1p(-x + self.eps) / self.ln2) + + def get_drdx(self, x: torch.Tensor) -> torch.Tensor: + """Returns the dr/dx + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + torch.Tensor + The dr/dx. + + """ + return self.ra / self.ln2 / (1 - x + self.eps) + + +class TreutlerM4Transformation(BaseGridTransform): + """Treutler M4 grid transformation + + Examples + -------- + >>> x = torch.linspace(-1, 1, 100) + >>> r = TreutlerM4Transformation().x2r(x) + >>> r.shape + torch.Size([100]) + >>> drdx = TreutlerM4Transformation().get_drdx(x) + >>> drdx.shape + torch.Size([100]) + + References + ---------- + .. [1] eq (19) in https://doi.org/10.1063/1.469408 + + """ + + def __init__(self, xi: float = 1.0, alpha: float = 0.6, eps: float = 1e-15): + """Initialize the TreutlerM4Transformation. + + Parameters + ---------- + xi: float (default 1.0) + The parameter to control the range of the grid. + alpha: float (default 0.6) + The parameter to control the range of the grid. + eps: float (default 1e-15) + The parameter to avoid numerical instability. + + """ + self._xi = xi + self._alpha = alpha + self._ln2 = np.log(2.0 + eps) + self._eps = eps + + def x2r(self, x: torch.Tensor) -> torch.Tensor: + """Transform from x to r coordinate + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + torch.Tensor + The coordinate from 0 to inf. + + """ + a = 1.0 + self._eps + r = self._xi / self._ln2 * (a + x) ** self._alpha * \ + (self._ln2 - torch.log1p(-x + self._eps)) + return r + + def get_drdx(self, x: torch.Tensor) -> torch.Tensor: + """Returns the dr/dx + + Parameters + ---------- + x: torch.Tensor + The coordinate from -1 to 1. + + Returns + ------- + torch.Tensor + The dr/dx. 
+ + """ + a = 1.0 + self._eps + fac = self._xi / self._ln2 * (a + x)**self._alpha + r1 = fac / (1 - x + self._eps) + r2 = fac * self._alpha / (a + x) * (self._ln2 - + torch.log1p(-x + self._eps)) + return r2 + r1 + + +def get_grid_transform(s0: Union[str, BaseGridTransform]) -> BaseGridTransform: + """grid transformation object from the input + + Examples + -------- + >>> transform = get_grid_transform("logm3") + >>> transform.x2r(torch.tensor([0.5])) + tensor([2.]) + + Parameters + ---------- + s0: Union[str, BaseGridTransform] + The grid transformation to use. Available options are `logm3`, + `de2`, and `treutlerm4`. + + Returns + ------- + BaseGridTransform + The grid transformation object. + + Raises + ------ + RuntimeError + If the input is not a valid grid transformation. + + """ + if isinstance(s0, BaseGridTransform): + return s0 + else: + s = s0.lower() + if s == "logm3": + return LogM3Transformation() + elif s == "de2": + return DE2Transformation() + elif s == "treutlerm4": + return TreutlerM4Transformation() + else: + raise RuntimeError("Unknown grid transformation: %s" % s0) diff --git a/deepchem/utils/test/test_dft_utils.py b/deepchem/utils/test/test_dft_utils.py index f85bca3bc2..c09984a73a 100644 --- a/deepchem/utils/test/test_dft_utils.py +++ b/deepchem/utils/test/test_dft_utils.py @@ -445,3 +445,33 @@ def requires_grid(self): system = MySystem() assert system.requires_grid() + + +@pytest.mark.torch +def test_radial_grid(): + from deepchem.utils.dft_utils import RadialGrid + grid = RadialGrid(4, grid_integrator="chebyshev", grid_transform="logm3") + assert grid.get_rgrid().shape == torch.Size([4, 1]) + assert grid.get_dvolume().shape == torch.Size([4]) + + +@pytest.mark.torch +def test_get_xw_integration(): + from deepchem.utils.dft_utils import get_xw_integration + x, w = get_xw_integration(4, "chebyshev") + assert x.shape == (4, ) + assert w.shape == torch.Size([4]) + + +@pytest.mark.torch +def test_sliced_radial_grid(): + from deepchem.utils.dft_utils import RadialGrid, SlicedRadialGrid + grid = RadialGrid(4) + sliced_grid = SlicedRadialGrid(grid, 2) + assert sliced_grid.get_rgrid().shape == torch.Size([1]) + + +@pytest.mark.torch +def test_base_grid_transform(): + from deepchem.utils.dft_utils import BaseGridTransform + diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst index 325c993359..31d1494b86 100644 --- a/docs/source/api_reference/utils.rst +++ b/docs/source/api_reference/utils.rst @@ -347,6 +347,30 @@ The utilites here are used to create an object that contains information about a .. autoclass:: deepchem.utils.dft_utils.system.base_system.BaseSystem :members: +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.RadialGrid + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.get_xw_integration + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.SlicedRadialGrid + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.BaseGridTransform + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.DE2Transformation + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.LogM3Transformation + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.TreutlerM4Transformation + :members: + +.. autoclass:: deepchem.utils.dft_utils.grid.radial_grid.get_grid_transform + :members: + .. 
autoclass:: deepchem.utils.differentiation_utils.editable_module.EditableModule
    :members:


From 2530fd81a02fea354e6bf82fa30e2c56641c032b Mon Sep 17 00:00:00 2001
From: GreatRSingh
Date: Fri, 19 Jan 2024 02:25:34 +0530
Subject: [PATCH 2/5] sibling pure functions - test remaining

---
 .../utils/differentiation_utils/__init__.py   |   1 +
 .../differentiation_utils/pure_function.py    | 135 +++++++++++++++++-
 2 files changed, 135 insertions(+), 1 deletion(-)

diff --git a/deepchem/utils/differentiation_utils/__init__.py b/deepchem/utils/differentiation_utils/__init__.py
index 5a03581f3f..e01772281b 100644
--- a/deepchem/utils/differentiation_utils/__init__.py
+++ b/deepchem/utils/differentiation_utils/__init__.py
@@ -21,5 +21,6 @@
     from deepchem.utils.differentiation_utils.pure_function import PureFunction
     from deepchem.utils.differentiation_utils.pure_function import get_pure_function
+    from deepchem.utils.differentiation_utils.pure_function import make_sibling
 except:
     pass
diff --git a/deepchem/utils/differentiation_utils/pure_function.py b/deepchem/utils/differentiation_utils/pure_function.py
index cdf1096dd6..c051a9dd41 100644
--- a/deepchem/utils/differentiation_utils/pure_function.py
+++ b/deepchem/utils/differentiation_utils/pure_function.py
@@ -1,6 +1,6 @@
 import torch
 import inspect
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Tuple, Union, Sequence
 from deepchem.utils.attribute_utils import set_attr, del_attr
 from deepchem.utils.differentiation_utils import EditableModule
 from deepchem.utils.misc_utils import Uniquifier
@@ -360,6 +360,111 @@ def _check_identical_objs(objs1: List, objs2: List) -> bool:
             return False
     return True
 
+class SingleSiblingPureFunction(PureFunction):
+    """Implementation of PureFunction for a sibling method
+
+    A sibling method is a method that virtually belongs to the same object,
+    but behaves differently.
+
+    Changing the state of the decorated function will also change the state of
+    ``pfunc`` and its other siblings.
+
+    """
+
+    def __init__(self, fcn: Callable, fcntocall: Callable):
+        """Initialize the SingleSiblingPureFunction.
+
+        Parameters
+        ----------
+        fcn: Callable
+            The sibling method to be wrapped
+        fcntocall: Callable
+            The method to be called when the wrapped function is invoked
+
+        """
+        self.pfunc = get_pure_function(fcn)
+        super().__init__(fcntocall)
+
+    def _get_all_obj_params_init(self) -> List:
+        """Get the initial object parameters.
+
+        Returns
+        -------
+        List
+            The initial object parameters
+
+        """
+        return self.pfunc._get_all_obj_params_init()
+
+    def _set_all_obj_params(self, allobjparams: List):
+        """Set the object parameters.
+
+        Parameters
+        ----------
+        allobjparams: List
+            The object parameters to be set
+
+        """
+        self.pfunc._set_all_obj_params(allobjparams)
+
+
+class MultiSiblingPureFunction(PureFunction):
+    """Implementation of PureFunction for multiple sibling methods
+
+    A sibling method is a method that virtually belongs to the same object,
+    but behaves differently.
+
+    Changing the state of the decorated function will also change the state of
+    ``pfunc`` and its other siblings.
+
+    """
+
+    def __init__(self, fcns: Sequence[Callable], fcntocall: Callable):
+        """Initialize the MultiSiblingPureFunction.
+
+        Parameters
+        ----------
+        fcns: Sequence[Callable]
+            The sibling methods to be wrapped
+        fcntocall: Callable
+            The method to be called when the wrapped function is invoked
+
+        """
+        self.pfuncs = [get_pure_function(fcn) for fcn in fcns]
+        self.npfuncs = len(self.pfuncs)
+        super().__init__(fcntocall)
+
+    def _get_all_obj_params_init(self) -> List:
+        """Get the initial object parameters.
+
+        Returns
+        -------
+        List
+            The initial object parameters
+
+        """
+        res: List[Union[torch.Tensor, torch.nn.Parameter]] = []
+        self.cumsum_idx = [0] * (self.npfuncs + 1)
+        for i, pfunc in enumerate(self.pfuncs):
+            objparams = pfunc._get_all_obj_params_init()
+            res = res + objparams
+            self.cumsum_idx[i + 1] = self.cumsum_idx[i] + len(objparams)
+        return res
+
+    def _set_all_obj_params(self, allobjparams: List):
+        """Set the object parameters.
+
+        Parameters
+        ----------
+        allobjparams: List
+            The object parameters to be set
+
+        """
+        for i, pfunc in enumerate(self.pfuncs):
+            pfunc._set_all_obj_params(
+                allobjparams[self.cumsum_idx[i]:self.cumsum_idx[i + 1]])
+
+
 def get_pure_function(fcn) -> PureFunction:
     """Get the pure function form of the function or method ``fcn``.
@@ -414,3 +519,31 @@ def get_pure_function(fcn) -> PureFunction:
 
     else:
         raise RuntimeError(errmsg)
+
+
+def make_sibling(*pfuncs) -> Callable[[Callable], PureFunction]:
+    """
+    Used as a decorator to mark the decorated function as a sibling method of
+    the input ``pfunc``.
+    A sibling method is a method that virtually belongs to the same object,
+    but behaves differently.
+    Changing the state of the decorated function will also change the state of
+    ``pfunc`` and its other siblings.
+
+    Parameters
+    ----------
+    pfuncs: List[Callable]
+        The sibling methods to be wrapped
+
+    Returns
+    -------
+    Callable[[Callable], PureFunction]
+        The decorator function
+
+    """
+    if len(pfuncs) == 0:
+        raise TypeError("At least 1 function is required as the argument")
+    elif len(pfuncs) == 1:
+        return lambda fcn: SingleSiblingPureFunction(pfuncs[0], fcntocall=fcn)
+    else:
+        return lambda fcn: MultiSiblingPureFunction(pfuncs, fcntocall=fcn)

From bfdc85d414ac9f1135da2aac70542d8e483e5567 Mon Sep 17 00:00:00 2001
From: GreatRSingh
Date: Fri, 19 Jan 2024 02:34:14 +0530
Subject: [PATCH 3/5] solve helper - 1 Full

---
 deepchem/utils/__init__.py                    |   1 +
 .../utils/differentiation_utils/__init__.py   |   7 +
 deepchem/utils/differentiation_utils/solve.py | 293 ++++++++++++++++++
 deepchem/utils/pytorch_utils.py               |  32 ++
 deepchem/utils/safeops_utils.py               |  29 ++
 .../utils/test/test_differentiation_utils.py  |  57 ++++
 deepchem/utils/test/test_misc_utils.py        |   2 +-
 deepchem/utils/test/test_pytorch_utils.py     |   7 +
 docs/source/api_reference/utils.rst           |  14 +
 9 files changed, 441 insertions(+), 1 deletion(-)
 create mode 100644 deepchem/utils/differentiation_utils/solve.py

diff --git a/deepchem/utils/__init__.py b/deepchem/utils/__init__.py
index c067839589..b7cb962e61 100644
--- a/deepchem/utils/__init__.py
+++ b/deepchem/utils/__init__.py
@@ -141,6 +141,7 @@
     from deepchem.utils.pytorch_utils import TensorNonTensorSeparator
     from deepchem.utils.pytorch_utils import tallqr
     from deepchem.utils.pytorch_utils import to_fortran_order
+    from deepchem.utils.pytorch_utils import get_np_dtype
 
     from deepchem.utils.safeops_utils import safepow
     from deepchem.utils.safeops_utils import safenorm
diff --git a/deepchem/utils/differentiation_utils/__init__.py b/deepchem/utils/differentiation_utils/__init__.py
index e01772281b..172546ea3e 100644
--- a/deepchem/utils/differentiation_utils/__init__.py
+++ 
b/deepchem/utils/differentiation_utils/__init__.py @@ -22,5 +22,12 @@ from deepchem.utils.differentiation_utils.pure_function import PureFunction from deepchem.utils.differentiation_utils.pure_function import get_pure_function from deepchem.utils.differentiation_utils.pure_function import make_sibling + + from deepchem.utils.differentiation_utils.solve import wrap_gmres + from deepchem.utils.differentiation_utils.solve import exactsolve + from deepchem.utils.differentiation_utils.solve import solve_ABE + from deepchem.utils.differentiation_utils.solve import get_batchdims + from deepchem.utils.differentiation_utils.solve import setup_precond + from deepchem.utils.differentiation_utils.solve import dot except: pass diff --git a/deepchem/utils/differentiation_utils/solve.py b/deepchem/utils/differentiation_utils/solve.py new file mode 100644 index 0000000000..02768478bb --- /dev/null +++ b/deepchem/utils/differentiation_utils/solve.py @@ -0,0 +1,293 @@ +import numpy as np +import torch +import warnings +from typing import Union, Optional, Callable +from deepchem.utils.differentiation_utils import LinearOperator, normalize_bcast_dims, get_bcasted_dims +from deepchem.utils import ConvergenceWarning, get_np_dtype +from scipy.sparse.linalg import gmres as scipy_gmres + + +# Hidden +def wrap_gmres(A, B, E=None, M=None, min_eps=1e-9, max_niter=None, **unused): + """ + Using SciPy's gmres method to solve the linear equation. + + Examples + -------- + >>> import torch + >>> from deepchem.utils.differentiation_utils import LinearOperator + >>> A = LinearOperator.m(torch.tensor([[1., 2], [3, 4]])) + >>> B = torch.tensor([[[5., 6], [7, 8]]]) + >>> wrap_gmres(A, B, None, None) + tensor([[[-3.0000, -4.0000], + [ 4.0000, 5.0000]]]) + + Parameters + ---------- + A: LinearOperator + The linear operator A to be solved. Shape: (*BA, na, na) + B: torch.Tensor + Batched matrix B. Shape: (*BB, na, ncols) + E: torch.Tensor or None + Batched vector E. Shape: (*BE, ncols) + M: LinearOperator or None + The linear operator M. Shape: (*BM, na, na) + min_eps: float + Relative tolerance for stopping conditions + max_niter: int or None + Maximum number of iterations. If ``None``, default to twice of the + number of columns of ``A``. + + Returns + ------- + torch.Tensor + The Solution matrix X. 
Shape: (*BBE, na, ncols)
+
+    """
+
+    # NOTE: currently only works for batched B (1 batch dim), but unbatched A
+    assert len(A.shape) == 2 and len(
+        B.shape
+    ) == 3, "Currently only works for batched B (1 batch dim), but unbatched A"
+    assert not torch.is_complex(B), "complex is not supported in gmres"
+
+    # check the parameters
+    msg = "GMRES can only do AX=B"
+    assert A.shape[-2] == A.shape[
+        -1], "GMRES can only work for square operator for now"
+    assert E is None, msg
+    assert M is None, msg
+
+    nbatch, na, ncols = B.shape
+    if max_niter is None:
+        max_niter = 2 * na
+
+    B = B.transpose(-1, -2)  # (nbatch, ncols, na)
+
+    # convert to numpy/scipy
+    op = A.scipy_linalg_op()
+    B_np = B.detach().cpu().numpy()
+    res_np = np.empty(B.shape, dtype=get_np_dtype(B.dtype))
+    for i in range(nbatch):
+        for j in range(ncols):
+            x, info = scipy_gmres(op,
+                                  B_np[i, j, :],
+                                  tol=min_eps,
+                                  atol=1e-12,
+                                  maxiter=max_niter)
+            if info > 0:
+                msg = "The GMRES iteration does not converge to the desired value "\
+                      "(%.3e) after %d iterations" % \
+                      (min_eps, info)
+                warnings.warn(ConvergenceWarning(msg))
+            res_np[i, j, :] = x
+
+    res = torch.tensor(res_np, dtype=B.dtype, device=B.device)
+    res = res.transpose(-1, -2)  # (nbatch, na, ncols)
+    return res
+
+
+def exactsolve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None],
+               M: Union[LinearOperator, None]):
+    """
+    Solve the linear equation by constructing the full matrix of LinearOperators.
+
+    Examples
+    --------
+    >>> import torch
+    >>> from deepchem.utils.differentiation_utils import LinearOperator
+    >>> A = LinearOperator.m(torch.tensor([[1., 2], [3, 4]]))
+    >>> B = torch.tensor([[5., 6], [7, 8]])
+    >>> exactsolve(A, B, None, None)
+    tensor([[-3., -4.],
+            [ 4.,  5.]])
+
+    Parameters
+    ----------
+    A: LinearOperator
+        The linear operator A to be solved. Shape: (*BA, na, na)
+    B: torch.Tensor
+        Batched matrix B. Shape: (*BB, na, ncols)
+    E: torch.Tensor or None
+        Batched vector E. Shape: (*BE, ncols)
+    M: LinearOperator or None
+        The linear operator M. Shape: (*BM, na, na)
+
+    Returns
+    -------
+    torch.Tensor
+        The Solution matrix X. Shape: (*BBE, na, ncols)
+
+    Warnings
+    --------
+    * As this method constructs the linear operators explicitly, it might
+      require a large amount of memory.
+
+    """
+    if E is None:
+        Amatrix = A.fullmatrix()
+        x = torch.linalg.solve(Amatrix, B)
+    elif M is None:
+        Amatrix = A.fullmatrix()
+        x = solve_ABE(Amatrix, B, E)
+    else:
+        Mmatrix = M.fullmatrix()
+        L = torch.linalg.cholesky(Mmatrix)
+        Linv = torch.inverse(L)
+        LinvT = Linv.transpose(-2, -1).conj()
+        A2 = torch.matmul(Linv, A.mm(LinvT))
+        B2 = torch.matmul(Linv, B)
+
+        X2 = solve_ABE(A2, B2, E)
+        x = torch.matmul(LinvT, X2)
+    return x
+
+
+def solve_ABE(A: torch.Tensor, B: torch.Tensor, E: torch.Tensor):
+    """ Solve the linear equation AX = B - diag(E)X.
+
+    Examples
+    --------
+    >>> import torch
+    >>> A = torch.tensor([[1., 2], [3, 4]])
+    >>> B = torch.tensor([[5., 6], [7, 8]])
+    >>> E = torch.tensor([1., 2])
+    >>> solve_ABE(A, B, E)
+    tensor([[-0.1667,  0.5000],
+            [ 2.5000,  3.2500]])
+
+    Parameters
+    ----------
+    A: torch.Tensor
+        The batched matrix A. Shape: (*BA, na, na)
+    B: torch.Tensor
+        The batched matrix B. Shape: (*BB, na, ncols)
+    E: torch.Tensor
+        The batched vector E. Shape: (*BE, ncols)
+
+    Returns
+    -------
+    torch.Tensor
+        The batched matrix X.
+ + """ + na = A.shape[-1] + BA, BB, BE = normalize_bcast_dims(A.shape[:-2], B.shape[:-2], E.shape[:-1]) + E = E.reshape(1, *BE, E.shape[-1]).transpose(0, -1) # (ncols, *BE, 1) + B = B.reshape(1, *BB, *B.shape[-2:]).transpose(0, -1) # (ncols, *BB, na, 1) + + # NOTE: The line below is very inefficient for large na and ncols + AE = A - torch.diag_embed(E.repeat_interleave(repeats=na, dim=-1), + dim1=-2, + dim2=-1) # (ncols, *BAE, na, na) + r = torch.linalg.solve(AE, B) # (ncols, *BAEM, na, 1) + r = r.transpose(0, -1).squeeze(0) # (*BAEM, na, ncols) + return r + + +# general helpers +def get_batchdims(A: LinearOperator, B: torch.Tensor, + E: Union[torch.Tensor, None], M: Union[LinearOperator, None]): + """Get the batch dimensions of the linear operator and the matrix B + + Examples + -------- + >>> from deepchem.utils.differentiation_utils import MatrixLinearOperator + >>> import torch + >>> A = MatrixLinearOperator(torch.randn(4, 3, 3), True) + >>> B = torch.randn(3, 3, 2) + >>> get_batchdims(A, B, None, None) + [4] + + Parameters + ---------- + A: LinearOperator + The linear operator. It can be a batched linear operator. + B: torch.Tensor + The matrix B. It can be a batched matrix. + E: Union[torch.Tensor, None] + The matrix E. It can be a batched matrix. + M: Union[LinearOperator, None] + The linear operator M. It can be a batched linear operator. + + Returns + ------- + List[int] + The batch dimensions of the linear operator and the matrix B + + """ + + batchdims = [A.shape[:-2], B.shape[:-2]] + if E is not None: + batchdims.append(E.shape[:-1]) + if M is not None: + batchdims.append(M.shape[:-2]) + return get_bcasted_dims(*batchdims) + + +def setup_precond( + precond: Optional[LinearOperator] = None +) -> Callable[[torch.Tensor], torch.Tensor]: + """Setup the preconditioning function + + Examples + -------- + >>> from deepchem.utils.differentiation_utils import MatrixLinearOperator + >>> import torch + >>> A = MatrixLinearOperator(torch.randn(4, 3, 3), True) + >>> B = torch.randn(4, 3, 2) + >>> cond = setup_precond(A) + >>> cond(B).shape + torch.Size([4, 3, 2]) + + Parameters + ---------- + precond: Optional[LinearOperator] + The preconditioning linear operator. If None, no preconditioning is + applied. + + Returns + ------- + Callable[[torch.Tensor], torch.Tensor] + The preconditioning function. It takes a tensor and returns a tensor. + + """ + if isinstance(precond, LinearOperator): + + def precond_fcn(x): + return precond.mm(x) + elif precond is None: + + def precond_fcn(x): + return x + else: + raise TypeError("precond can only be LinearOperator or None") + return precond_fcn + + +def dot(r: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + """Dot product of two vectors. r and z must have the same shape. + Then sums it up across the last dimension. + + Examples + -------- + >>> import torch + >>> r = torch.tensor([[1, 2], [3, 4]]) + >>> z = torch.tensor([[5, 6], [7, 8]]) + >>> dot(r, z) + tensor([[26, 44]]) + + Parameters + ---------- + r: torch.Tensor + The first vector. Shape: (*BR, nr, nc) + z: torch.Tensor + The second vector. Shape: (*BR, nr, nc) + + Returns + ------- + torch.Tensor + The dot product of r and z. 
Shape: (*BR, 1, nc)
+
+    """
+    return torch.einsum("...rc,...rc->...c", r.conj(), z).unsqueeze(-2)
diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py
index 9390736298..ea2eb4cf59 100644
--- a/deepchem/utils/pytorch_utils.py
+++ b/deepchem/utils/pytorch_utils.py
@@ -3,6 +3,7 @@
 import scipy
 import torch
 from typing import Callable, Union, List, Generator, Tuple
+import numpy as np
 
 
 def get_activation(fn: Union[Callable, str]):
@@ -433,3 +434,34 @@ def to_fortran_order(V):
     else:
         raise RuntimeError(
             "Only the last two dimensions can be made Fortran order.")
+
+def get_np_dtype(dtype: torch.dtype) -> np.dtype:
+    """corresponding numpy dtype from the input pytorch's tensor dtype
+
+    >>> from deepchem.utils.pytorch_utils import get_np_dtype
+    >>> get_np_dtype(torch.float32)
+    <class 'numpy.float32'>
+    >>> get_np_dtype(torch.float64)
+    <class 'numpy.float64'>
+
+    Parameters
+    ----------
+    dtype: torch.dtype
+        pytorch's tensor dtype
+
+    Returns
+    -------
+    np.dtype
+        corresponding numpy dtype
+
+    """
+    if dtype == torch.float32:
+        return np.float32
+    elif dtype == torch.float64:
+        return np.float64
+    elif dtype == torch.complex64:
+        return np.complex64
+    elif dtype == torch.complex128:
+        return np.complex128
+    else:
+        raise TypeError("Unknown type: %s" % dtype)
diff --git a/deepchem/utils/safeops_utils.py b/deepchem/utils/safeops_utils.py
index f4246c4b77..3c7fa49d9e 100644
--- a/deepchem/utils/safeops_utils.py
+++ b/deepchem/utils/safeops_utils.py
@@ -347,3 +347,32 @@ def safe_cdist(a: torch.Tensor,
         ab = ab + infdiag
 
     return ab
+
+
+def safedenom(r: torch.Tensor, eps: float) -> torch.Tensor:
+    """Avoid division by zero by replacing zero elements with eps.
+
+    Used in CG and BiCGStab.
+
+    Examples
+    --------
+    >>> import torch
+    >>> r = torch.tensor([1e-11, 0])
+    >>> safedenom(r, 1e-12)
+    tensor([1.0000e-11, 1.0000e-12])
+
+    Parameters
+    ----------
+    r: torch.Tensor
+        The residual vector
+    eps: float
+        The minimum value to avoid division by zero
+
+    Returns
+    -------
+    r: torch.Tensor
+        The residual vector with zero elements replaced by eps
+
+    """
+    r[r == 0] = eps
+    return r
diff --git a/deepchem/utils/test/test_differentiation_utils.py b/deepchem/utils/test/test_differentiation_utils.py
index 8481aba17d..c5b1b98b82 100644
--- a/deepchem/utils/test/test_differentiation_utils.py
+++ b/deepchem/utils/test/test_differentiation_utils.py
@@ -641,3 +641,60 @@ def fcn(x, y):
 
     pfunc = get_pure_function(fcn)
     assert pfunc(1, 2) == 3
+
+
+@pytest.mark.torch
+def test_wrap_gmres():
+    from deepchem.utils.differentiation_utils.solve import wrap_gmres
+    from deepchem.utils.differentiation_utils import LinearOperator
+    A = LinearOperator.m(torch.tensor([[1., 2], [3, 4]]))
+    B = torch.tensor([[[5., 6], [7, 8]]])
+    assert torch.allclose(A.fullmatrix() @ wrap_gmres(A, B, None, None), B)
+
+
+@pytest.mark.torch
+def test_exact_solve():
+    from deepchem.utils.differentiation_utils.solve import exactsolve
+    from deepchem.utils.differentiation_utils import LinearOperator
+    A = LinearOperator.m(torch.tensor([[1., 2], [3, 4]]))
+    B = torch.tensor([[5., 6], [7, 8]])
+    assert torch.allclose(A.fullmatrix() @ exactsolve(A, B, None, None), B)
+
+
+@pytest.mark.torch
+def test_solve_ABE():
+    from deepchem.utils.differentiation_utils.solve import solve_ABE
+    A = torch.tensor([[1., 2], [3, 4]])
+    B = torch.tensor([[5., 6], [7, 8]])
+    E = torch.tensor([1., 2])
+    expected_result = torch.tensor([[-0.1667, 0.5000], [2.5000, 3.2500]])
+    assert torch.allclose(solve_ABE(A, B, E), expected_result, 0.001)
+
+
+@pytest.mark.torch
+def test_get_batch_dims():
from deepchem.utils.differentiation_utils.solve import get_batchdims + from deepchem.utils.differentiation_utils import MatrixLinearOperator + A = MatrixLinearOperator(torch.randn(4, 3, 3), True) + B = torch.randn(3, 3, 2) + assert get_batchdims(A, B, None, + None) == [max(A.shape[:-2], B.shape[:-2])[0]] + + +@pytest.mark.torch +def test_setup_precond(): + from deepchem.utils.differentiation_utils.solve import setup_precond + from deepchem.utils.differentiation_utils import MatrixLinearOperator + A = MatrixLinearOperator(torch.randn(4, 3, 3), True) + B = torch.randn(4, 3, 2) + cond = setup_precond(A) + assert cond(B).shape == torch.Size([4, 3, 2]) + + +@pytest.mark.torch +def test_dot(): + from deepchem.utils.differentiation_utils.solve import dot + r = torch.tensor([[1, 2], [3, 4]]) + z = torch.tensor([[5, 6], [7, 8]]) + assert torch.allclose(dot(r, z), torch.tensor([[26, 44]])) + assert torch.allclose(dot(r, z), sum(r * z)) diff --git a/deepchem/utils/test/test_misc_utils.py b/deepchem/utils/test/test_misc_utils.py index e155bb5766..8c3159ece9 100644 --- a/deepchem/utils/test/test_misc_utils.py +++ b/deepchem/utils/test/test_misc_utils.py @@ -5,4 +5,4 @@ def test_uniquifier(): c = 3 d = 1 u = Uniquifier([a, b, c, a, d]) - u.get_unique_objs() == [1, 2, 3] + assert u.get_unique_objs() == [1, 2, 3] diff --git a/deepchem/utils/test/test_pytorch_utils.py b/deepchem/utils/test/test_pytorch_utils.py index 1fff2bef11..610d747c84 100644 --- a/deepchem/utils/test/test_pytorch_utils.py +++ b/deepchem/utils/test/test_pytorch_utils.py @@ -109,3 +109,10 @@ def test_to_fortran_order(): if V.is_contiguous() is True: assert False assert V.shape == torch.Size([3, 2]) + + +@pytest.mark.torch +def test_get_np_dtype(): + """Test the get_np_dtype utility.""" + assert dc.utils.pytorch_utils.get_np_dtype(torch.float32) == np.float32 + assert dc.utils.pytorch_utils.get_np_dtype(torch.float64) == np.float64 diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst index 31d1494b86..f8d1ee16db 100644 --- a/docs/source/api_reference/utils.rst +++ b/docs/source/api_reference/utils.rst @@ -437,6 +437,18 @@ The utilites here are used to create an object that contains information about a .. autofunction:: deepchem.utils.differentiation_utils.davidson +.. autofunction:: deepchem.utils.differentiation_utils.solve.wrap_gmres + +.. autofunction:: deepchem.utils.differentiation_utils.solve.exactsolve + +.. autofunction:: deepchem.utils.differentiation_utils.solve.solve_ABE + +.. autofunction:: deepchem.utils.differentiation_utils.solve.get_batchdims + +.. autofunction:: deepchem.utils.differentiation_utils.solve.setup_precond + +.. autofunction:: deepchem.utils.differentiation_utils.solve.dot + Attribute Utilities ------------------- @@ -470,6 +482,8 @@ Pytorch Utilities .. autofunction:: deepchem.utils.pytorch_utils.to_fortran_order +.. 
autofunction:: deepchem.utils.pytorch_utils.get_np_dtype
+
 Batch Utilities
 ---------------

From d5837defdfae757b2df5a158297800e1b1b2c160 Mon Sep 17 00:00:00 2001
From: GreatRSingh
Date: Fri, 19 Jan 2024 02:51:28 +0530
Subject: [PATCH 4/5] solve helper - 2 Full

---
 .../utils/differentiation_utils/__init__.py   |   4 +
 deepchem/utils/differentiation_utils/solve.py | 361 +++++++++++++++++-
 deepchem/utils/pytorch_utils.py               |   6 +-
 .../utils/test/test_differentiation_utils.py  |  42 ++
 docs/source/api_reference/utils.rst           |   8 +
 5 files changed, 419 insertions(+), 2 deletions(-)

diff --git a/deepchem/utils/differentiation_utils/__init__.py b/deepchem/utils/differentiation_utils/__init__.py
index 172546ea3e..1c5a5a988e 100644
--- a/deepchem/utils/differentiation_utils/__init__.py
+++ b/deepchem/utils/differentiation_utils/__init__.py
@@ -29,5 +29,9 @@
     from deepchem.utils.differentiation_utils.solve import get_batchdims
     from deepchem.utils.differentiation_utils.solve import setup_precond
     from deepchem.utils.differentiation_utils.solve import dot
+    from deepchem.utils.differentiation_utils.solve import get_largest_eival
+    from deepchem.utils.differentiation_utils.solve import safedenom
+    from deepchem.utils.differentiation_utils.solve import setup_linear_problem
+    from deepchem.utils.differentiation_utils.solve import gmres
 except:
     pass
diff --git a/deepchem/utils/differentiation_utils/solve.py b/deepchem/utils/differentiation_utils/solve.py
index 02768478bb..24670175d7 100644
--- a/deepchem/utils/differentiation_utils/solve.py
+++ b/deepchem/utils/differentiation_utils/solve.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import warnings
-from typing import Union, Optional, Callable
+from typing import Sequence, Tuple, Union, Optional, Callable
 from deepchem.utils.differentiation_utils import LinearOperator, normalize_bcast_dims, get_bcasted_dims
 from deepchem.utils import ConvergenceWarning, get_np_dtype
 from scipy.sparse.linalg import gmres as scipy_gmres
@@ -185,6 +185,149 @@ def solve_ABE(A: torch.Tensor, B: torch.Tensor, E: torch.Tensor):
     return r
 
 
+def gmres(A: LinearOperator,
+          B: torch.Tensor,
+          E: Optional[torch.Tensor] = None,
+          M: Optional[LinearOperator] = None,
+          posdef: Optional[bool] = None,
+          max_niter: Optional[int] = None,
+          rtol: float = 1e-6,
+          atol: float = 1e-8,
+          eps: float = 1e-12,
+          **unused) -> torch.Tensor:
+    r"""
+    Solve the linear equations using the generalized minimal residual (GMRES) method.
+
+    Examples
+    --------
+    >>> import torch
+    >>> from deepchem.utils.differentiation_utils import LinearOperator
+    >>> A = LinearOperator.m(torch.tensor([[1., 2], [3, 4]]))
+    >>> B = torch.tensor([[5., 6], [7, 8]])
+    >>> gmres(A, B)
+    tensor([[0.8959, 1.0697],
+            [1.2543, 1.4263]])
+
+    Parameters
+    ----------
+    A: LinearOperator
+        The linear operator A to be solved. Shape: (*BA, na, na)
+    B: torch.Tensor
+        Batched matrix B. Shape: (*BB, na, ncols)
+    E: torch.Tensor or None
+        Batched vector E. Shape: (*BE, ncols)
+    M: LinearOperator or None
+        The linear operator M. Shape: (*BM, na, na)
+    posdef: bool or None
+        Indicates whether the operation :math:`\mathbf{AX-MXE}` is positive
+        definite for all columns and batches.
+        If None, it will be determined by power iterations.
+    max_niter: int or None
+        Maximum number of iterations. If None, it is set to ``A.shape[-1]``
+    rtol: float
+        Relative tolerance for stopping condition w.r.t. norm of B
+    atol: float
+        Absolute tolerance for stopping condition w.r.t.
norm of B + eps: float + Substitute the absolute zero in the algorithm's denominator with this + value to avoid nan. + + Returns + ------- + torch.Tensor + The solution matrix X. Shape: (*BBE, na, ncols) + + """ + converge = False + + nr, ncols = A.shape[-1], B.shape[-1] + if max_niter is None: + max_niter = int(nr) + + # if B is all zeros, then return zeros + batchdims = get_batchdims(A, B, E, M) + if torch.allclose(B, B * 0, rtol=rtol, atol=atol): + x0 = torch.zeros((*batchdims, nr, ncols), + dtype=A.dtype, + device=A.device) + return x0 + + # setup the preconditioning and the matrix problem + need_hermit = False + A_fcn, AT_fcn, B2, col_swapped = setup_linear_problem( + A, B, E, M, batchdims, posdef, need_hermit) + + # get the stopping matrix + B_norm = B2.norm(dim=-2, keepdim=True) # (*BB, 1, nc) + stop_matrix = torch.max(rtol * B_norm, + atol * torch.ones_like(B_norm)) # (*BB, 1, nc) + + # prepare the initial guess (it's just all zeros) + x0shape = (ncols, *batchdims, nr, 1) if col_swapped else (*batchdims, nr, + ncols) + x0 = torch.zeros(x0shape, dtype=A.dtype, device=A.device) + + r = B2 - A_fcn(x0) # torch.Size([*batch_dims, nr, ncols]) + best_resid = r.norm(dim=-2, keepdim=True) # / B_norm + + best_resid = best_resid.max().item() + best_res = x0 + q = torch.empty([max_niter] + list(r.shape), dtype=A.dtype, device=A.device) + q[0] = r / safedenom(r.norm(dim=-2, keepdim=True), + eps) # torch.Size([*batch_dims, nr, ncols]) + h = torch.zeros((*batchdims, ncols, max_niter + 1, max_niter), + dtype=A.dtype, + device=A.device) + h = h.reshape((-1, ncols, max_niter + 1, max_niter)) + + for k in range(min(nr, max_niter)): + y = A_fcn(q[k]) # torch.Size([*batch_dims, nr, ncols]) + for j in range(k + 1): + h[..., j, k] = dot(q[j], y).reshape(-1, ncols) + y = y - h[..., j, k].reshape(*batchdims, 1, ncols) * q[j] + + h[..., k + 1, k] = torch.linalg.norm(y, dim=-2) + if torch.any(h[..., k + 1, k]) != 0 and k != max_niter - 1: + q[k + 1] = y.reshape(-1, nr, ncols) / h[..., k + 1, k].reshape( + -1, 1, ncols) + q[k + 1] = q[k + 1].reshape(*batchdims, nr, ncols) + + b = torch.zeros((*batchdims, ncols, k + 1), + dtype=A.dtype, + device=A.device) + b = b.reshape(-1, ncols, k + 1) + b[..., 0] = torch.linalg.norm(r, dim=-2) + rk = torch.linalg.lstsq(h[..., :k + 1, :k], b)[0] + + res = torch.empty([]) + for i in range(k): + res = res + q[i] * rk[..., i].reshape(*batchdims, 1, ncols) + x0 if res.size() \ + else q[i] * rk[..., i].reshape(*batchdims, 1, ncols) + x0 + # res = res * B_norm + + if res.size(): + resid = B2 - A_fcn(res) + resid_norm = resid.norm(dim=-2, keepdim=True) + + # save the best results + max_resid_norm = resid_norm.max().item() + if max_resid_norm < best_resid: + best_resid = max_resid_norm + best_res = res + + if torch.all(resid_norm < stop_matrix): + converge = True + break + + if not converge: + msg = ("Convergence is not achieved after %d iterations. 
" + "Max norm of resid: %.3e") % (max_niter, best_resid) + warnings.warn(ConvergenceWarning(msg)) + + res = best_res + return res + + # general helpers def get_batchdims(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None], M: Union[LinearOperator, None]): @@ -265,6 +408,177 @@ def precond_fcn(x): return precond_fcn +def setup_linear_problem(A: LinearOperator, B: torch.Tensor, + E: Optional[torch.Tensor], M: Optional[LinearOperator], + batchdims: Sequence[int], + posdef: Optional[bool], + need_hermit: bool) -> \ + Tuple[Callable[[torch.Tensor], torch.Tensor], + Callable[[torch.Tensor], torch.Tensor], + torch.Tensor, bool]: + """Setup the linear problem for solving AX = B + + Examples + -------- + >>> from deepchem.utils.differentiation_utils import MatrixLinearOperator + >>> import torch + >>> A = MatrixLinearOperator(torch.randn(4, 3, 3), True) + >>> B = torch.randn(4, 3, 2) + >>> A_fcn, AT_fcn, B_new, col_swapped = setup_linear_problem(A, B, None, None, [4], None, False) + >>> A_fcn(B).shape + torch.Size([4, 3, 2]) + + Parameters + ---------- + A: LinearOperator + The linear operator A. It can be a batched linear operator. + B: torch.Tensor + The matrix B. It can be a batched matrix. + E: Optional[torch.Tensor] + The matrix E. It can be a batched matrix. + M: Optional[LinearOperator] + The linear operator M. It can be a batched linear operator. + batchdims: Sequence[int] + The batch dimensions of the linear operator and the matrix B + posdef: Optional[bool] + Whether the linear operator is positive definite. If None, it will be + estimated. + need_hermit: bool + Whether the linear operator is Hermitian. If True, it will be estimated. + + Returns + ------- + Tuple[Callable[[torch.Tensor], torch.Tensor], + Callable[[torch.Tensor], torch.Tensor], + torch.Tensor, bool] + The function A, its transposed function, the matrix B, and whether the + columns of B are swapped. + + """ + + # get the linear operator (including the MXE part) + if E is None: + + def A_fcn(x): + return A.mm(x) + + def AT_fcn(x): + return A.rmm(x) + + B_new = B + col_swapped = False + else: + # A: (*BA, nr, nr) linop + # B: (*BB, nr, ncols) + # E: (*BE, ncols) + # M: (*BM, nr, nr) linop + if M is None: + BAs, BBs, BEs = normalize_bcast_dims(A.shape[:-2], B.shape[:-2], + E.shape[:-1]) + else: + BAs, BBs, BEs, BMs = normalize_bcast_dims(A.shape[:-2], + B.shape[:-2], + E.shape[:-1], + M.shape[:-2]) + E = E.reshape(*BEs, *E.shape[-1:]) + E_new = E.unsqueeze(0).transpose(-1, + 0).unsqueeze(-1) # (ncols, *BEs, 1, 1) + B = B.reshape(*BBs, *B.shape[-2:]) # (*BBs, nr, ncols) + B_new = B.unsqueeze(0).transpose(-1, 0) # (ncols, *BBs, nr, 1) + + def A_fcn(x): + # x: (ncols, *BX, nr, 1) + Ax = A.mm(x) # (ncols, *BAX, nr, 1) + Mx = M.mm(x) if M is not None else x # (ncols, *BMX, nr, 1) + MxE = Mx * E_new # (ncols, *BMXE, nr, 1) + return Ax - MxE + + def AT_fcn(x): + # x: (ncols, *BX, nr, 1) + ATx = A.rmm(x) + MTx = M.rmm(x) if M is not None else x + MTxE = MTx * E_new + return ATx - MTxE + + col_swapped = True + + # estimate if it's posdef with power iteration + if need_hermit: + is_hermit = A.is_hermitian and (M is None or M.is_hermitian) + if not is_hermit: + # set posdef to False to make the operator becomes AT * A so it is + # hermitian + posdef = False + + # TODO: the posdef check by largest eival only works for Hermitian/symmetric + # matrix, but it doesn't always work for non-symmetric matrix. 
+ # In non-symmetric case, one need to do Cholesky LDL decomposition + if posdef is None: + nr, ncols = B.shape[-2:] + x0shape = (ncols, *batchdims, nr, 1) if col_swapped else (*batchdims, + nr, ncols) + x0 = torch.randn(x0shape, dtype=A.dtype, device=A.device) + x0 = x0 / x0.norm(dim=-2, keepdim=True) + largest_eival = get_largest_eival(A_fcn, x0) # (*, 1, nc) + negeival = largest_eival <= 0 + + # if the largest eigenvalue is negative, then it's not posdef + if torch.all(negeival): + posdef = False + + # otherwise, calculate the lowest eigenvalue to check if it's positive + else: + offset = torch.clamp(largest_eival, min=0.0) + + def A_fcn2(x): + return A_fcn(x) - offset * x + + mostneg_eival = get_largest_eival(A_fcn2, x0) # (*, 1, nc) + posdef = bool( + torch.all(torch.logical_or(-mostneg_eival <= offset, + negeival)).item()) + + # get the linear operation if it is not a posdef (A -> AT.A) + if posdef: + return A_fcn, AT_fcn, B_new, col_swapped + else: + + def A_new_fcn(x): + return AT_fcn(A_fcn(x)) + + B2 = AT_fcn(B_new) + return A_new_fcn, A_new_fcn, B2, col_swapped + + +# cg and bicgstab helpers +def safedenom(r: torch.Tensor, eps: float) -> torch.Tensor: + """Make sure the denominator is not zero + + Examples + -------- + >>> import torch + >>> r = torch.tensor([[0., 2], [3, 4]]) + >>> safedenom(r, 1e-9) + tensor([[1.0000e-09, 2.0000e+00], + [3.0000e+00, 4.0000e+00]]) + + Parameters + ---------- + r: torch.Tensor + The input tensor. Shape: (*BR, nr, nc) + eps: float + The small number to replace the zero denominator + + Returns + ------- + torch.Tensor + The tensor with non-zero denominator. Shape: (*BR, nr, nc) + + """ + r[r == 0] = eps + return r + + def dot(r: torch.Tensor, z: torch.Tensor) -> torch.Tensor: """Dot product of two vectors. r and z must have the same shape. Then sums it up across the last dimension. @@ -291,3 +605,48 @@ def dot(r: torch.Tensor, z: torch.Tensor) -> torch.Tensor: """ return torch.einsum("...rc,...rc->...c", r.conj(), z).unsqueeze(-2) + + +def get_largest_eival(Afcn: Callable, x: torch.Tensor) -> torch.Tensor: + """Get the largest eigenvalue of the linear operator Afcn + + Examples + -------- + >>> import torch + >>> def Afcn(x): + ... return 10 * x + >>> x = torch.tensor([[1., 2], [3, 4]]) + >>> get_largest_eival(Afcn, x) + tensor([[10., 10.]]) + + Parameters + ---------- + Afcn: Callable + The linear operator A. It takes a tensor and returns a tensor. + x: torch.Tensor + The input tensor. Shape: (*, nr, nc) + + Returns + ------- + torch.Tensor + The largest eigenvalue. 
Shape: (*, 1, nc) + + """ + niter = 10 + rtol = 1e-3 + atol = 1e-6 + xnorm_prev = None + for i in range(niter): + x = Afcn(x) # (*, nr, nc) + xnorm = x.norm(dim=-2, keepdim=True) # (*, 1, nc) + + # check if xnorm is converging + if i > 0: + dnorm = torch.abs(xnorm_prev - xnorm) + if torch.all(dnorm <= rtol * xnorm + atol): + break + + xnorm_prev = xnorm + if i < niter - 1: + x = x / xnorm + return xnorm diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py index ea2eb4cf59..702513bf1a 100644 --- a/deepchem/utils/pytorch_utils.py +++ b/deepchem/utils/pytorch_utils.py @@ -2,7 +2,7 @@ import scipy import torch -from typing import Callable, Union, List, Generator, Tuple +from typing import Any, Callable, Union, List, Generator, Tuple import numpy as np @@ -435,9 +435,13 @@ def to_fortran_order(V): raise RuntimeError( "Only the last two dimensions can be made Fortran order.") + def get_np_dtype(dtype: torch.dtype) -> np.dtype: """corresponding numpy dtype from the input pytorch's tensor dtype + Examples + -------- + >>> import torch >>> from deepchem.utils.pytorch_utils import get_np_dtype >>> get_np_dtype(torch.float32) diff --git a/deepchem/utils/test/test_differentiation_utils.py b/deepchem/utils/test/test_differentiation_utils.py index c5b1b98b82..a17cc6d073 100644 --- a/deepchem/utils/test/test_differentiation_utils.py +++ b/deepchem/utils/test/test_differentiation_utils.py @@ -698,3 +698,45 @@ def test_dot(): z = torch.tensor([[5, 6], [7, 8]]) assert torch.allclose(dot(r, z), torch.tensor([[26, 44]])) assert torch.allclose(dot(r, z), sum(r * z)) + + +@pytest.mark.torch +def test_gmres(): + from deepchem.utils.differentiation_utils.solve import gmres + from deepchem.utils.differentiation_utils import MatrixLinearOperator + A = MatrixLinearOperator(torch.tensor([[1., 2], [3, 4]]), True) + B = torch.tensor([[5., 6], [7, 8]]) + expected_result = torch.tensor([[0.8959, 1.0697], [1.2543, 1.4263]]) + assert torch.allclose(gmres(A, B), expected_result, 0.001) + + +@pytest.mark.torch +def test_setup_linear_problem(): + from deepchem.utils.differentiation_utils import MatrixLinearOperator + from deepchem.utils.differentiation_utils.solve import setup_linear_problem + A = MatrixLinearOperator(torch.randn(4, 3, 3), True) + B = torch.randn(4, 3, 2) + A_fcn, AT_fcn, B_new, col_swapped = setup_linear_problem( + A, B, None, None, [4], None, False) + assert A_fcn(B).shape == torch.Size([4, 3, 2]) + + +@pytest.mark.torch +def test_safe_denom(): + from deepchem.utils.differentiation_utils.solve import safedenom + r = torch.tensor([[0., 2], [3, 4]]) + assert torch.allclose( + safedenom(r, 1e-9), + torch.tensor([[1.0000e-09, 2.0000e+00], [3.0000e+00, 4.0000e+00]])) + + +@pytest.mark.torch +def test_get_largest_eival(): + from deepchem.utils.differentiation_utils.solve import get_largest_eival + + def Afcn(x): + return 10 * x + + x = torch.tensor([[1., 2], [3, 4]]) + assert torch.allclose(get_largest_eival(Afcn, x), torch.tensor([[10., + 10.]])) diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst index f8d1ee16db..cd048e635b 100644 --- a/docs/source/api_reference/utils.rst +++ b/docs/source/api_reference/utils.rst @@ -449,6 +449,14 @@ The utilites here are used to create an object that contains information about a .. autofunction:: deepchem.utils.differentiation_utils.solve.dot +.. autofunction:: deepchem.utils.differentiation_utils.solve.gmres + +.. autofunction:: deepchem.utils.differentiation_utils.solve.setup_linear_problem + +.. 
autofunction:: deepchem.utils.differentiation_utils.solve.safedenom + +.. autofunction:: deepchem.utils.differentiation_utils.solve.get_largest_eival + Attribute Utilities ------------------- From b8813a7d31df1f5718bc003409fd4d8c85056d4d Mon Sep 17 00:00:00 2001 From: GreatRSingh Date: Fri, 19 Jan 2024 17:52:31 +0530 Subject: [PATCH 5/5] Squashed commit of the following: commit 0a3ec8ff7cdba43faf19eebb2fc495a7da9417c3 Author: GreatRSingh Date: Fri Jan 19 17:33:57 2024 +0530 sibling test commit 2b83c5ffdb941d1069e7ebaa37828010faaa95a6 Author: GreatRSingh Date: Fri Jan 19 16:55:51 2024 +0530 transform tests --- .../differentiation_utils/pure_function.py | 18 ++++++++- deepchem/utils/pytorch_utils.py | 2 +- deepchem/utils/test/test_dft_utils.py | 39 +++++++++++++++++-- .../utils/test/test_differentiation_utils.py | 14 +++++++ 4 files changed, 67 insertions(+), 6 deletions(-) diff --git a/deepchem/utils/differentiation_utils/pure_function.py b/deepchem/utils/differentiation_utils/pure_function.py index c051a9dd41..2cc821f770 100644 --- a/deepchem/utils/differentiation_utils/pure_function.py +++ b/deepchem/utils/differentiation_utils/pure_function.py @@ -360,6 +360,7 @@ def _check_identical_objs(objs1: List, objs2: List) -> bool: return False return True + class SingleSiblingPureFunction(PureFunction): """Implementation of PureFunction for a sibling method @@ -465,7 +466,6 @@ def _set_all_obj_params(self, allobjparams: List): allobjparams[self.cumsum_idx[i]:self.cumsum_idx[i + 1]]) - def get_pure_function(fcn) -> PureFunction: """Get the pure function form of the function or method ``fcn``. @@ -530,6 +530,22 @@ def make_sibling(*pfuncs) -> Callable[[Callable], PureFunction]: Changing the state of the decorated function will also change the state of ``pfunc`` and its other siblings. + Examples + -------- + >>> import torch + >>> from deepchem.utils.differentiation_utils import make_sibling + >>> def fcn1(x, y): + ... return x + y + >>> def fcn2(x, y): + ... return x - y + >>> pfunc1 = get_pure_function(fcn1) + >>> pfunc2 = get_pure_function(fcn2) + >>> @make_sibling(pfunc1) + ... def fcn3(x, y): + ... 
return x * y
+    >>> fcn3(1, 2)
+    2
+
     Parameters
     ----------
     pfuncs: List[Callable]
         The sibling methods to be wrapped
diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py
index 702513bf1a..b87335b074 100644
--- a/deepchem/utils/pytorch_utils.py
+++ b/deepchem/utils/pytorch_utils.py
@@ -436,7 +436,7 @@ def to_fortran_order(V):
             "Only the last two dimensions can be made Fortran order.")
 
 
-def get_np_dtype(dtype: torch.dtype) -> np.dtype:
+def get_np_dtype(dtype: torch.dtype) -> Any:
     """corresponding numpy dtype from the input pytorch's tensor dtype
 
     Examples
diff --git a/deepchem/utils/test/test_dft_utils.py b/deepchem/utils/test/test_dft_utils.py
index c09984a73a..6ab7279193 100644
--- a/deepchem/utils/test/test_dft_utils.py
+++ b/deepchem/utils/test/test_dft_utils.py
@@ -459,7 +459,7 @@ def test_radial_grid():
 def test_get_xw_integration():
     from deepchem.utils.dft_utils import get_xw_integration
     x, w = get_xw_integration(4, "chebyshev")
-    assert x.shape == (4, )
+    assert x.shape == (4,)
     assert w.shape == torch.Size([4])
 
 
@@ -472,6 +472,37 @@ def test_sliced_radial_grid():
 
 
 @pytest.mark.torch
-def test_base_grid_transform():
-    from deepchem.utils.dft_utils import BaseGridTransform
-
+def test_de2_transform():
+    from deepchem.utils.dft_utils import DE2Transformation
+    x = torch.linspace(-1, 1, 100)
+    r = DE2Transformation().x2r(x)
+    assert r.shape == torch.Size([100])
+    drdx = DE2Transformation().get_drdx(x)
+    assert drdx.shape == torch.Size([100])
+
+
+@pytest.mark.torch
+def test_logm3_transform():
+    from deepchem.utils.dft_utils import LogM3Transformation
+    x = torch.linspace(-1, 1, 100)
+    r = LogM3Transformation().x2r(x)
+    assert r.shape == torch.Size([100])
+    drdx = LogM3Transformation().get_drdx(x)
+    assert drdx.shape == torch.Size([100])
+
+
+@pytest.mark.torch
+def test_treutlerm4_transform():
+    from deepchem.utils.dft_utils import TreutlerM4Transformation
+    x = torch.linspace(-1, 1, 100)
+    r = TreutlerM4Transformation().x2r(x)
+    assert r.shape == torch.Size([100])
+    drdx = TreutlerM4Transformation().get_drdx(x)
+    assert drdx.shape == torch.Size([100])
+
+
+@pytest.mark.torch
+def test_get_grid_transform():
+    from deepchem.utils.dft_utils import get_grid_transform
+    transform = get_grid_transform("logm3")
+    assert torch.allclose(transform.x2r(torch.tensor([0.5])),
+                          torch.tensor([2.]))
diff --git a/deepchem/utils/test/test_differentiation_utils.py b/deepchem/utils/test/test_differentiation_utils.py
index a17cc6d073..338bb0ac4a 100644
--- a/deepchem/utils/test/test_differentiation_utils.py
+++ b/deepchem/utils/test/test_differentiation_utils.py
@@ -643,6 +643,20 @@ def fcn(x, y):
     assert pfunc(1, 2) == 3
 
 
+@pytest.mark.torch
+def test_make_siblings():
+    from deepchem.utils.differentiation_utils import make_sibling
+
+    def fcn1(x, y):
+        return x + y
+
+    @make_sibling(fcn1)
+    def fcn3(x, y):
+        return x * y
+
+    assert fcn3(1, 2) == 2
+
+
 @pytest.mark.torch
 def test_wrap_gmres():
     from deepchem.utils.differentiation_utils.solve import wrap_gmres
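
A minimal end-to-end sketch of the APIs added in this series, assuming a
DeepChem checkout with all five patches applied; the numeric expectations
below are approximate, not exact outputs:

    import torch
    from deepchem.utils.dft_utils import RadialGrid
    from deepchem.utils.differentiation_utils import LinearOperator
    from deepchem.utils.differentiation_utils.solve import gmres

    # Integrate exp(-r^2) over 3D space on a 100-point radial grid.
    # get_dvolume() already carries the 4*pi*r^2*dr factor, so the sum
    # should approach the analytic value pi**1.5 (about 5.568).
    grid = RadialGrid(100, grid_integrator="chebyshev", grid_transform="logm3")
    r = grid.get_rgrid().squeeze(-1)  # (100,)
    integral = (grid.get_dvolume() * torch.exp(-r * r)).sum()

    # Solve AX = B with the new gmres() on a small symmetric positive
    # definite matrix and verify the residual against the full matrix.
    A = LinearOperator.m(torch.tensor([[4., 1.], [1., 3.]]))
    B = torch.tensor([[1.], [2.]])
    X = gmres(A, B)
    assert torch.allclose(A.fullmatrix() @ X, B, atol=1e-4)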