Skip to content

Commit

Permalink
adding atom_index field and simplified basis creation
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Oct 1, 2024
1 parent 48379e7 commit 8551772
Showing 1 changed file with 10 additions and 29 deletions.
39 changes: 10 additions & 29 deletions mess/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from mess.primitive import Primitive
from mess.structure import Structure
from mess.types import (
Float3,
FloatN,
FloatNx3,
FloatNxM,
Expand Down Expand Up @@ -56,6 +55,7 @@ def fixer(x):

df = pd.DataFrame()
df["orbital"] = self.orbital_index
df["atom"] = self.primitives.atom_index
df["coefficient"] = self.coefficients
df["norm"] = self.primitives.norm
df["center"] = fixer(self.primitives.center)
Expand Down Expand Up @@ -105,13 +105,18 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis:
Basis constructed from inputs
"""
orbitals = []
atom_index = []

for a in range(structure.num_atoms):
element = int(structure.atomic_number[a])
center = structure.position[a, :]
orbitals += _build_orbitals(basis_name, element, center)
for atom_id in range(structure.num_atoms):
element = int(structure.atomic_number[atom_id])
out = _bse_to_orbitals(basis_name, element)
atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out))
orbitals += out

primitives, coefficients, orbital_index = batch_orbitals(orbitals)
primitives = eqx.tree_at(lambda p: p.atom_index, primitives, jnp.array(atom_index))
center = structure.position[primitives.atom_index, :]
primitives = eqx.tree_at(lambda p: p.center, primitives, center)

return Basis(
orbitals=orbitals,
Expand Down Expand Up @@ -175,30 +180,6 @@ def _bse_to_orbitals(basis_name: str, atomic_number: int) -> Tuple[Orbital]:
return tuple(orbitals)


def _build_orbitals(
basis_name: str, atomic_number: int, center: Float3
) -> Tuple[Orbital]:
"""
Constructs a tuple of Orbital objects for a given atomic_number and basis set,
with each orbital centered at the specified coordinates.
Args:
basis_name (str): The name of the basis set to use.
atomic_number (int): The atomic number used to build the orbitals.
center (Float3): the 3D coordinate specifying the center of the orbitals
Returns:
Tuple[Orbital]: A tuple of Orbitals centered at the provided coordinates.
"""
orbitals = _bse_to_orbitals(basis_name, atomic_number)

def where(orbitals):
return [p.center for ao in orbitals for p in ao.primitives]

num_centers = len(where(orbitals))
return eqx.tree_at(where, orbitals, replace=np.tile(center, (num_centers, 1)))


def basis_iter(basis: Basis):
from jax import tree

Expand Down

0 comments on commit 8551772

Please sign in to comment.