Skip to content

Commit

Permalink
Merge pull request #652 from zerothi/650-slicing
Browse files Browse the repository at this point in the history
added sanitize to SparseCSR and __setitem__
  • Loading branch information
zerothi authored Nov 9, 2023
2 parents 20dd22f + 1a3b788 commit 82c1e43
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ we hit release version 1.0.0.

## [0.14.4] - YYYY-MM-DD

### Fixed
- enabled slicing in matrix assignments, #650



## [0.14.3] - 2023-11-07
Expand Down
23 changes: 23 additions & 0 deletions src/sisl/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,18 @@ def _(self, atoms: ndarray) -> ndarray:
return np.flatnonzero(atoms)
return atoms

@_sanitize_atoms.register
def _(self, atoms: slice) -> ndarray:
    """Expand a slice of atoms into an explicit array of indices

    Parameters
    ----------
    atoms : slice
        the slice to expand. A missing ``start`` defaults to 0, a missing
        ``stop`` to ``self.na`` and a missing ``step`` to 1. Explicitly
        given values are passed on unchanged (negative values are *not*
        wrapped around).

    Returns
    -------
    numpy.ndarray
        ``np.arange(start, stop, step)``
    """
    # TODO consider doing range(self.na)[atoms]
    first = 0 if atoms.start is None else atoms.start
    last = self.na if atoms.stop is None else atoms.stop
    stride = 1 if atoms.step is None else atoms.step
    return np.arange(first, last, stride)

@_sanitize_atoms.register
def _(self, atoms: str) -> ndarray:
return self.names[atoms]
Expand Down Expand Up @@ -449,6 +461,17 @@ def _(self, orbitals: str) -> ndarray:
atoms = self._sanitize_atoms(orbitals)
return self.a2o(atoms, all=True)

@_sanitize_orbs.register
def _(self, orbitals: slice) -> ndarray:
    """Expand a slice of orbitals into an explicit array of indices

    Parameters
    ----------
    orbitals : slice
        the slice to expand. A missing ``start`` defaults to 0, a missing
        ``stop`` to ``self.no`` and a missing ``step`` to 1. Explicitly
        given values are passed on unchanged (negative values are *not*
        wrapped around).

    Returns
    -------
    numpy.ndarray
        ``np.arange(start, stop, step)``
    """
    start, stop, step = orbitals.start, orbitals.stop, orbitals.step
    if start is None:
        start = 0
    if stop is None:
        # This is an *orbital* slice, hence the open end is bounded by the
        # number of orbitals and not by the number of atoms (``self.na``),
        # otherwise ``S[i, 1:]`` is truncated for multi-orbital atoms.
        stop = self.no
    if step is None:
        step = 1
    return np.arange(start, stop, step)

@_sanitize_orbs.register
def _(self, orbitals: AtomCategory) -> ndarray:
atoms = self._sanitize_atoms(orbitals)
Expand Down
2 changes: 0 additions & 2 deletions src/sisl/io/orca/tests/test_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def test_charge_orbital_reduced_unpol(sisl_files):
assert S is None


@pytest.mark.only
def test_charge_orbital_full_unpol(sisl_files):
f = sisl_files(_dir, "molecule2.output")
out = stdoutSileORCA(f)
Expand All @@ -284,7 +283,6 @@ def test_charge_orbital_full_unpol(sisl_files):
assert S is None


@pytest.mark.only
def test_read_energy(sisl_files):
f = sisl_files(_dir, "molecule.output")
out = stdoutSileORCA(f)
Expand Down
72 changes: 53 additions & 19 deletions src/sisl/sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from functools import reduce
from functools import reduce, singledispatchmethod
from numbers import Integral

import numpy as np
Expand All @@ -15,6 +15,7 @@
argsort,
asarray,
atleast_1d,
bool_,
broadcast,
concatenate,
copyto,
Expand Down Expand Up @@ -512,6 +513,40 @@ def finalize(self, sort=True):
# Signal that we indeed have finalized the data
self._finalized = sort

@singledispatchmethod
def _sanitize(self, idx, axis=0) -> ndarray:
    """Sanitize the input indices to a conforming numpy array

    Parameters
    ----------
    idx : None or array_like
        the indices to sanitize. ``None`` selects every index along `axis`.
        A boolean sequence is converted to the positions of its true
        entries.
    axis : int, optional
        the axis of ``self.shape`` that `idx` refers to, only used when
        `idx` is ``None``. A negative value selects the largest entry of
        ``self.shape``.

    Returns
    -------
    numpy.ndarray
        the sanitized indices
    """
    if idx is None:
        if axis < 0:
            return _a.arangei(np.max(self.shape))
        return _a.arangei(self.shape[axis])

    # Inspect the dtype *before* the integer conversion.
    # NOTE(review): `_a.asarrayi` presumably casts to an integer dtype
    # (TODO confirm); in that case a dtype test performed after the
    # conversion can never be true and a mask such as [True, False]
    # would be interpreted as the indices [1, 0].
    raw = np.asarray(idx)
    # 0-d booleans are left to the integer conversion below
    if raw.ndim > 0 and raw.dtype == bool_:
        return raw.nonzero()[0].astype(np.int32)

    idx = _a.asarrayi(idx)
    if idx.size == 0:
        return _a.asarrayi([])
    return idx

@_sanitize.register
def _(self, idx: ndarray, axis=0) -> ndarray:
    """Sanitize indices that are already a numpy array

    Boolean arrays are converted to the flat positions of their true
    entries, anything else is cast to 32-bit integers (without copying
    when the dtype already matches). `axis` is not used here.
    """
    if idx.dtype != bool_:
        return idx.astype(np.int32, copy=False)
    return np.flatnonzero(idx).astype(np.int32)

@_sanitize.register
def _(self, idx: slice, axis=0) -> ndarray:
    """Expand a slice into an explicit array of indices

    Parameters
    ----------
    idx : slice
        the slice to expand. A missing ``start`` defaults to 0 and a
        missing ``step`` to 1. Explicitly given values are passed on
        unchanged.
    axis : int, optional
        determines the default ``stop``: ``self.shape[axis]``, or the
        largest entry of ``self.shape`` when `axis` is negative.
    """
    first = 0 if idx.start is None else idx.start
    stride = 1 if idx.step is None else idx.step
    last = idx.stop
    if last is None:
        # open ended slice, bounded by the extent of the requested axis
        last = self.shape[axis] if axis >= 0 else np.max(self.shape)
    return _a.arangei(first, last, stride)

def edges(self, row, exclude=None):
"""Retrieve edges (connections) of a given `row` or list of `row`'s
Expand All @@ -524,11 +559,11 @@ def edges(self, row, exclude=None):
exclude : int or list of int, optional
remove edges which are in the `exclude` list.
"""
row = unique(_a.asarrayi(row))
row = unique(self._sanitize(row))
if exclude is None:
exclude = []
else:
exclude = unique(_a.asarrayi(exclude))
exclude = unique(self._sanitize(exclude))

# Now get the edges
ptr = self.ptr
Expand Down Expand Up @@ -558,7 +593,7 @@ def delete_columns(self, columns, keep_shape=False):
cnz = count_nonzero

# Sort the columns
columns = unique(_a.asarrayi(columns))
columns = unique(self._sanitize(columns, axis=1))
n_cols = cnz(columns < self.shape[1])

# Grab pointers
Expand Down Expand Up @@ -674,8 +709,8 @@ def translate_columns(self, old, new, rows=None, clean=True):
clean : bool, optional
whether the new translated columns, outside the shape, should be deleted or not (default delete)
"""
old = _a.asarrayi(old)
new = _a.asarrayi(new)
old = self._sanitize(old, axis=1)
new = self._sanitize(new, axis=1)

if len(old) != len(new):
raise ValueError(
Expand Down Expand Up @@ -730,7 +765,7 @@ def scale_columns(self, cols, scale, rows=None):
rows : int or array_like, optional
only scale the column values that exists in these rows, default to all
"""
cols = _a.asarrayi(cols)
cols = self._sanitize(cols, axis=1)

if np_any(cols >= self.shape[1]):
raise ValueError(
Expand Down Expand Up @@ -857,7 +892,7 @@ def iter_nnz(self, row=None):
for c in self.col[ptr : ptr + n]:
yield r, c
else:
for r in _a.asarrayi(row).ravel():
for r in self._sanitize(row).ravel():
n = self.ncol[r]
ptr = self.ptr[r]
for c in self.col[ptr : ptr + n]:
Expand Down Expand Up @@ -899,6 +934,7 @@ def _extend(self, i, j, ret_indices=True):
IndexError
for indices out of bounds
"""
i = self._sanitize(i)
if asarray(i).size == 0:
return arrayi([])
if i < 0 or i >= self.shape[0]:
Expand All @@ -910,7 +946,7 @@ def _extend(self, i, j, ret_indices=True):
# " must only be performed at one row-element at a time.\n"
# "However, multiple columns at a time are allowed.")
# Ensure flattened array...
j = asarrayi(j).ravel()
j = self._sanitize(j, axis=1).ravel()
if len(j) == 0:
return arrayi([])
if np_any(j < 0) or np_any(j >= self.shape[1]):
Expand Down Expand Up @@ -1063,8 +1099,7 @@ def _get(self, i, j):
numpy.ndarray
indices of the existing elements
"""
# Ensure flattened array...
j = asarrayi(j)
j = self._sanitize(j, axis=1)

# Make it a little easier
ptr = self.ptr[i]
Expand All @@ -1088,8 +1123,7 @@ def _get_only(self, i, j):
numpy.ndarray
indices of existing elements
"""
# Ensure flattened array...
j = asarrayi(j).ravel()
j = self._sanitize(j, axis=1).ravel()

# Make it a little easier
ptr = self.ptr[i]
Expand Down Expand Up @@ -1208,9 +1242,9 @@ def __setitem__(self, key, data):
data[isnan(data)] = 0

# Determine how the indices should work
i = key[0]
j = key[1]
if isinstance(i, (list, ndarray)) and isinstance(j, (list, ndarray)):
i = self._sanitize(key[0])
j = self._sanitize(key[1], axis=1)
if i.size > 1 and isinstance(j, (list, ndarray)):
# Create a b-cast object to iterate
# Note that this does not do the actual b-casting and thus
# we can iterate and operate as though it was an actual array
Expand Down Expand Up @@ -1335,7 +1369,7 @@ def nonzero(self, rows=None, only_cols=False):
idx = (ncol > 0).nonzero()[0]
rows = repeat(idx.astype(int32, copy=False), ncol[idx])
else:
rows = _a.asarray(rows).ravel()
rows = self._sanitize(rows).ravel()
ncol = ncol[rows]
cols = col[array_arange(ptr[rows], n=ncol, dtype=int32)]
if not only_cols:
Expand Down Expand Up @@ -1562,7 +1596,7 @@ def remove(self, indices):
indices : array_like
the indices of the rows *and* columns that are removed in the sparse pattern
"""
indices = asarrayi(indices)
indices = self._sanitize(indices, axis=-1)

# Check if we have a square matrix or a rectangular one
if self.shape[0] >= self.shape[1]:
Expand All @@ -1583,7 +1617,7 @@ def sub(self, indices):
indices : array_like
the indices of the rows *and* columns that are retained in the sparse pattern
"""
indices = asarrayi(indices).ravel()
indices = self._sanitize(indices, axis=-1).ravel()

# Check if we have a square matrix or a rectangular one
if self.shape[0] == self.shape[1]:
Expand Down
10 changes: 8 additions & 2 deletions src/sisl/sparse_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,14 +1225,17 @@ def __setitem__(self, key, val):
Override set item for slicing operations and enables easy
setting of parameters in a sparse matrix
"""
dd = self._def_dim
if len(key) > 2:
# This may be a specification of supercell indices
if isinstance(key[-1], tuple):
# We guess it is the supercell index
off = self.geometry.sc_index(key[-1]) * self.na
key = [el for el in key[:-1]]
key[1] = self.geometry.sc2uc(key[1]) + off
key = tuple(
self.geometry._sanitize_atoms(k) if i < 2 else k for i, k in enumerate(key)
)
dd = self._def_dim
if dd >= 0:
key = tuple(key) + (dd,)
self._def_dim = -1
Expand Down Expand Up @@ -1648,14 +1651,17 @@ def __setitem__(self, key, val):
Override set item for slicing operations and enables easy
setting of parameters in a sparse matrix
"""
dd = self._def_dim
if len(key) > 2:
# This may be a specification of supercell indices
if isinstance(key[-1], tuple):
# We guess it is the supercell index
off = self.geometry.sc_index(key[-1]) * self.no
key = [el for el in key[:-1]]
key[1] = self.geometry.osc2uc(key[1]) + off
key = tuple(
self.geometry._sanitize_orbs(k) if i < 2 else k for i, k in enumerate(key)
)
dd = self._def_dim
if dd >= 0:
key = tuple(key) + (dd,)
self._def_dim = -1
Expand Down
13 changes: 13 additions & 0 deletions src/sisl/tests/test_sparse_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,12 @@ def test_numpy_reduction(self, setup):
S.finalize()
assert np.sum(S, axis=(0, 1)) == pytest.approx(1 * 2 * 2)

def test_sanitize_atoms_assign(self, setup):
    """Assigning with a slice as column key must not raise (see #650)"""
    geom = graphene(atoms=Atom(6, R=1.43))
    sparse = SparseAtom(geom)
    # smoke test only: the assignment itself is what is being checked
    for row in (0, 1):
        sparse[row, 1:4] = 1

def test_fromsp1(self, setup):
g = setup.g.repeat(2, 0).tile(2, 1)
lil = sc.sparse.lil_matrix((g.na, g.na_s), dtype=np.int32)
Expand Down Expand Up @@ -769,3 +775,10 @@ def test_translate_sparse_atoms():
assert transl_both[1, 1] == 2
assert transl_both[0, 1] == 0
assert transl_both[0, 3] == 3


def test_sanitize_orbs_assign():
    """Assigning with a slice as column key must not raise (see #650)"""
    geom = graphene(atoms=Atom(6, R=[1.43, 1.66]))
    sparse = SparseOrbital(geom)
    # smoke test only: the assignment itself is what is being checked
    for row in (0, 1):
        sparse[row, 1:4] = 1

0 comments on commit 82c1e43

Please sign in to comment.