Skip to content

Commit

Permalink
Discovering orbital pairs on the fly using ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Mar 8, 2023
1 parent 63a685e commit 960ecfa
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 38 deletions.
66 changes: 35 additions & 31 deletions sisl/physics/_compute_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@cython.boundscheck(False)
@cython.wraparound(False)
def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], col_orbs_uc: cython.int[:],
def add_cnc_diag_spin(state: complex_or_float[:, :], DM_ptr: cython.int[:], DM_col_uc: cython.int[:],
occs: cython.floating[:], DM_kpoint: complex_or_float[:], occtol: float = 1e-9):
"""Adds the cnc contributions of all orbital pairs to the DM given a array of states.
Expand All @@ -18,10 +18,10 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
state:
The coefficients of all eigenstates for this contribution.
Array of shape (n_eigenstates, n_basisorbitals)
row_orbs:
The orbital row indices of the sparsity pattern.
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
col_orbs_uc:
DM_ptr:
The pointer to row array of the sparse DM.
Shape (no + 1, ), where no is the number of orbitals in the unit cell.
DM_col_uc:
The orbital col indices of the sparsity pattern, but converted to the unit cell.
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
occs:
Expand All @@ -37,11 +37,13 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
i: cython.int
u: cython.int
v: cython.int
ipair: cython.int

# Number of orbitals in the unit cell
no: cython.int = DM_ptr.shape[0] - 1
ival: cython.int

# Loop lengths
n_wfs: cython.int = state.shape[0]
n_opairs: cython.int = row_orbs.shape[0]

# Variable to store the occupation of each state
occ: float
Expand All @@ -54,17 +56,17 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
if occ < occtol:
continue

# The occupation is above the tolerance threshold, loop through all overlaping orbital pairs
for ipair in range(n_opairs):
# Get the orbital indices of this pair
u = row_orbs[ipair]
v = col_orbs_uc[ipair]
# Add the contribution of this eigenstate to the DM_{u,v} element
DM_kpoint[ipair] = DM_kpoint[ipair] + state[i, u] * occ * state[i, v].conjugate()
# Loop over all non zero elements in the sparsity pattern
for u in range(no):
for ival in range(DM_ptr[u], DM_ptr[u+1]):
v = DM_col_uc[ival]
# Add the contribution of this eigenstate to the DM_{u,v} element
DM_kpoint[ival] = DM_kpoint[ival] + state[i, u] * occ * state[i, v].conjugate()


@cython.boundscheck(False)
@cython.wraparound(False)
def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs_uc: cython.int[:],
def add_cnc_nc(state: cython.complex[:, :, :], DM_ptr: cython.int[:], DM_col_uc: cython.int[:],
occs: cython.floating[:], DM_kpoint: cython.complex[:, :, :], occtol: float = 1e-9):
"""Adds the cnc contributions of all orbital pairs to the DM given a array of states.
Expand All @@ -76,10 +78,10 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
The coefficients of all eigenstates for this contribution.
Array of shape (n_eigenstates, n_basisorbitals, 2), where the last dimension is the spin
"up"/"down" dimension.
row_orbs:
The orbital row indices of the sparsity pattern.
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
col_orbs_uc:
DM_ptr:
The pointer to row array of the sparse DM.
Shape (no + 1, ), where no is the number of orbitals in the unit cell.
DM_col_uc:
The orbital col indices of the sparsity pattern, but converted to the unit cell.
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
occs:
Expand All @@ -96,14 +98,17 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
i: cython.int
u: cython.int
v: cython.int
ipair: cython.int
ival: cython.int

# Number of orbitals in the unit cell
no: cython.int = DM_ptr.shape[0] - 1

# The spin box indices.
Di: cython.int
Dj: cython.int

# Loop lengths
n_wfs: cython.int = state.shape[0]
n_opairs: cython.int = row_orbs.shape[0]

# Variable to store the occupation of each state
occ: float
Expand All @@ -115,14 +120,13 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
# If the occupation is lower than the tolerance, skip the state
if occ < occtol:
continue

# The occupation is above the tolerance threshold, loop through all overlaping orbital pairs
for ipair in range(n_opairs):
# Get the orbital indices of this pair
u = row_orbs[ipair]
v = col_orbs_uc[ipair]

# Loop over all non zero elements in the sparsity pattern
for u in range(no):
for ival in range(DM_ptr[u], DM_ptr[u+1]):
v = DM_col_uc[ival]

# Add to spin box
for Di in range(2):
for Dj in range(2):
DM_kpoint[ipair, Di, Dj] = DM_kpoint[ipair, Di, Dj] + state[i, u, Di] * occ * state[i, v, Dj].conjugate()
# Add to spin box
for Di in range(2):
for Dj in range(2):
DM_kpoint[ival, Di, Dj] = DM_kpoint[ival, Di, Dj] + state[i, u, Di] * occ * state[i, v, Dj].conjugate()
12 changes: 5 additions & 7 deletions sisl/physics/compute_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def compute_dm(bz: BrillouinZone, eigenstates: Optional[Sequence[EigenstateElect
geom = H.geometry

# Sparsity pattern information
row_orbs, col_orbs = H.nonzero()
col_orbs_uc = H.osc2uc(col_orbs)
col_isc = col_orbs // H.no
col_isc, col_uc = np.divmod(H._csr.col, H.no)
sc_offsets = H.sc_off.dot(H.cell)

# Initialize the density matrix using the sparsity pattern of the Hamiltonian.
Expand Down Expand Up @@ -121,8 +119,8 @@ def compute_dm(bz: BrillouinZone, eigenstates: Optional[Sequence[EigenstateElect

if DM.spin.is_diagonal:
# Calculate the matrix elements contributions for this k point.
DM_kpoint = np.zeros(row_orbs.shape[0], dtype=k_eigs.state.dtype)
add_cnc_diag_spin(state, row_orbs, col_orbs_uc, occs, DM_kpoint, occtol=occtol)
DM_kpoint = np.zeros(DM.nnz, dtype=k_eigs.state.dtype)
add_cnc_diag_spin(state, H._csr.ptr, col_uc, occs, DM_kpoint, occtol=occtol)

# Apply phases
DM_kpoint = DM_kpoint * phases
Expand All @@ -139,8 +137,8 @@ def compute_dm(bz: BrillouinZone, eigenstates: Optional[Sequence[EigenstateElect

# Calculate the matrix elements contributions for this k point. For each matrix element
# we allocate a 2x2 spin box.
DM_kpoint = np.zeros((row_orbs.shape[0], 2, 2), dtype=np.complex128)
add_cnc_nc(state, row_orbs, col_orbs_uc, occs, DM_kpoint, occtol=occtol)
DM_kpoint = np.zeros((DM.nnz, 2, 2), dtype=np.complex128)
add_cnc_nc(state, H._csr.ptr, col_uc, occs, DM_kpoint, occtol=occtol)

# Apply phases
DM_kpoint *= phases.reshape(-1, 1, 1)
Expand Down

0 comments on commit 960ecfa

Please sign in to comment.