Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added sanitize to SparseCSR and __setitem__ #652

Merged
merged 1 commit into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ we hit release version 1.0.0.

## [0.14.4] - YYYY-MM-DD

### Fixed
- enabled slicing in matrix assignments, #650



## [0.14.3] - 2023-11-07
Expand Down
23 changes: 23 additions & 0 deletions src/sisl/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,18 @@
return np.flatnonzero(atoms)
return atoms

@_sanitize_atoms.register
def _(self, atoms: slice) -> ndarray:
# TODO consider doing range(self.na)[atoms]
start, stop, step = atoms.start, atoms.stop, atoms.step
if start is None:
start = 0

Check warning on line 391 in src/sisl/geometry.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geometry.py#L391

Added line #L391 was not covered by tests
if stop is None:
stop = self.na

Check warning on line 393 in src/sisl/geometry.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geometry.py#L393

Added line #L393 was not covered by tests
if step is None:
step = 1
return np.arange(start, stop, step)

@_sanitize_atoms.register
def _(self, atoms: str) -> ndarray:
return self.names[atoms]
Expand Down Expand Up @@ -449,6 +461,17 @@
atoms = self._sanitize_atoms(orbitals)
return self.a2o(atoms, all=True)

@_sanitize_orbs.register
def _(self, orbitals: slice) -> ndarray:
start, stop, step = orbitals.start, orbitals.stop, orbitals.step
if start is None:
start = 0

Check warning on line 468 in src/sisl/geometry.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geometry.py#L468

Added line #L468 was not covered by tests
if stop is None:
stop = self.na

Check warning on line 470 in src/sisl/geometry.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/geometry.py#L470

Added line #L470 was not covered by tests
if step is None:
step = 1
return np.arange(start, stop, step)

@_sanitize_orbs.register
def _(self, orbitals: AtomCategory) -> ndarray:
atoms = self._sanitize_atoms(orbitals)
Expand Down
2 changes: 0 additions & 2 deletions src/sisl/io/orca/tests/test_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def test_charge_orbital_reduced_unpol(sisl_files):
assert S is None


@pytest.mark.only
def test_charge_orbital_full_unpol(sisl_files):
f = sisl_files(_dir, "molecule2.output")
out = stdoutSileORCA(f)
Expand All @@ -284,7 +283,6 @@ def test_charge_orbital_full_unpol(sisl_files):
assert S is None


@pytest.mark.only
def test_read_energy(sisl_files):
f = sisl_files(_dir, "molecule.output")
out = stdoutSileORCA(f)
Expand Down
72 changes: 53 additions & 19 deletions src/sisl/sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from functools import reduce
from functools import reduce, singledispatchmethod
from numbers import Integral

import numpy as np
Expand All @@ -15,6 +15,7 @@
argsort,
asarray,
atleast_1d,
bool_,
broadcast,
concatenate,
copyto,
Expand Down Expand Up @@ -512,6 +513,40 @@
# Signal that we indeed have finalized the data
self._finalized = sort

@singledispatchmethod
def _sanitize(self, idx, axis=0) -> ndarray:
"""Sanitize the input indices to a conforming numpy array"""
if idx is None:
if axis < 0:
return _a.arangei(np.max(self.shape))
return _a.arangei(self.shape[axis])

Check warning on line 522 in src/sisl/sparse.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/sparse.py#L520-L522

Added lines #L520 - L522 were not covered by tests
idx = _a.asarrayi(idx)
if idx.size == 0:
return _a.asarrayi([])
elif idx.dtype == bool_:
return idx.nonzero()[0].astype(np.int32)

Check warning on line 527 in src/sisl/sparse.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/sparse.py#L527

Added line #L527 was not covered by tests
return idx

@_sanitize.register
def _(self, idx: ndarray, axis=0) -> ndarray:
if idx.dtype == bool_:
return np.flatnonzero(idx).astype(np.int32)

Check warning on line 533 in src/sisl/sparse.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/sparse.py#L533

Added line #L533 was not covered by tests
return idx.astype(np.int32, copy=False)

@_sanitize.register
def _(self, idx: slice, axis=0) -> ndarray:
start, stop, step = idx.start, idx.stop, idx.step
if start is None:
start = 0
if stop is None:
if axis < 0:
stop = np.max(self.shape)

Check warning on line 543 in src/sisl/sparse.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/sparse.py#L538-L543

Added lines #L538 - L543 were not covered by tests
else:
stop = self.shape[axis]
if step is None:
step = 1
return _a.arangei(start, stop, step)

Check warning on line 548 in src/sisl/sparse.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/sparse.py#L545-L548

Added lines #L545 - L548 were not covered by tests

def edges(self, row, exclude=None):
"""Retrieve edges (connections) of a given `row` or list of `row`'s

Expand All @@ -524,11 +559,11 @@
exclude : int or list of int, optional
remove edges which are in the `exclude` list.
"""
row = unique(_a.asarrayi(row))
row = unique(self._sanitize(row))
if exclude is None:
exclude = []
else:
exclude = unique(_a.asarrayi(exclude))
exclude = unique(self._sanitize(exclude))

# Now get the edges
ptr = self.ptr
Expand Down Expand Up @@ -558,7 +593,7 @@
cnz = count_nonzero

# Sort the columns
columns = unique(_a.asarrayi(columns))
columns = unique(self._sanitize(columns, axis=1))
n_cols = cnz(columns < self.shape[1])

# Grab pointers
Expand Down Expand Up @@ -674,8 +709,8 @@
clean : bool, optional
whether the new translated columns, outside the shape, should be deleted or not (default delete)
"""
old = _a.asarrayi(old)
new = _a.asarrayi(new)
old = self._sanitize(old, axis=1)
new = self._sanitize(new, axis=1)

if len(old) != len(new):
raise ValueError(
Expand Down Expand Up @@ -730,7 +765,7 @@
rows : int or array_like, optional
only scale the column values that exists in these rows, default to all
"""
cols = _a.asarrayi(cols)
cols = self._sanitize(cols, axis=1)

if np_any(cols >= self.shape[1]):
raise ValueError(
Expand Down Expand Up @@ -857,7 +892,7 @@
for c in self.col[ptr : ptr + n]:
yield r, c
else:
for r in _a.asarrayi(row).ravel():
for r in self._sanitize(row).ravel():
n = self.ncol[r]
ptr = self.ptr[r]
for c in self.col[ptr : ptr + n]:
Expand Down Expand Up @@ -899,6 +934,7 @@
IndexError
for indices out of bounds
"""
i = self._sanitize(i)
if asarray(i).size == 0:
return arrayi([])
if i < 0 or i >= self.shape[0]:
Expand All @@ -910,7 +946,7 @@
# " must only be performed at one row-element at a time.\n"
# "However, multiple columns at a time are allowed.")
# Ensure flattened array...
j = asarrayi(j).ravel()
j = self._sanitize(j, axis=1).ravel()
if len(j) == 0:
return arrayi([])
if np_any(j < 0) or np_any(j >= self.shape[1]):
Expand Down Expand Up @@ -1063,8 +1099,7 @@
numpy.ndarray
indices of the existing elements
"""
# Ensure flattened array...
j = asarrayi(j)
j = self._sanitize(j, axis=1)

# Make it a little easier
ptr = self.ptr[i]
Expand All @@ -1088,8 +1123,7 @@
numpy.ndarray
indices of existing elements
"""
# Ensure flattened array...
j = asarrayi(j).ravel()
j = self._sanitize(j, axis=1).ravel()

# Make it a little easier
ptr = self.ptr[i]
Expand Down Expand Up @@ -1208,9 +1242,9 @@
data[isnan(data)] = 0

# Determine how the indices should work
i = key[0]
j = key[1]
if isinstance(i, (list, ndarray)) and isinstance(j, (list, ndarray)):
i = self._sanitize(key[0])
j = self._sanitize(key[1], axis=1)
if i.size > 1 and isinstance(j, (list, ndarray)):
# Create a b-cast object to iterate
# Note that this does not do the actual b-casting and thus
# we can iterate and operate as though it was an actual array
Expand Down Expand Up @@ -1335,7 +1369,7 @@
idx = (ncol > 0).nonzero()[0]
rows = repeat(idx.astype(int32, copy=False), ncol[idx])
else:
rows = _a.asarray(rows).ravel()
rows = self._sanitize(rows).ravel()
ncol = ncol[rows]
cols = col[array_arange(ptr[rows], n=ncol, dtype=int32)]
if not only_cols:
Expand Down Expand Up @@ -1562,7 +1596,7 @@
indices : array_like
the indices of the rows *and* columns that are removed in the sparse pattern
"""
indices = asarrayi(indices)
indices = self._sanitize(indices, axis=-1)

# Check if we have a square matrix or a rectangular one
if self.shape[0] >= self.shape[1]:
Expand All @@ -1583,7 +1617,7 @@
indices : array_like
the indices of the rows *and* columns that are retained in the sparse pattern
"""
indices = asarrayi(indices).ravel()
indices = self._sanitize(indices, axis=-1).ravel()

# Check if we have a square matrix or a rectangular one
if self.shape[0] == self.shape[1]:
Expand Down
10 changes: 8 additions & 2 deletions src/sisl/sparse_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,14 +1225,17 @@ def __setitem__(self, key, val):
Override set item for slicing operations and enables easy
setting of parameters in a sparse matrix
"""
dd = self._def_dim
if len(key) > 2:
# This may be a specification of supercell indices
if isinstance(key[-1], tuple):
# We guess it is the supercell index
off = self.geometry.sc_index(key[-1]) * self.na
key = [el for el in key[:-1]]
key[1] = self.geometry.sc2uc(key[1]) + off
key = tuple(
self.geometry._sanitize_atoms(k) if i < 2 else k for i, k in enumerate(key)
)
dd = self._def_dim
if dd >= 0:
key = tuple(key) + (dd,)
self._def_dim = -1
Expand Down Expand Up @@ -1648,14 +1651,17 @@ def __setitem__(self, key, val):
Override set item for slicing operations and enables easy
setting of parameters in a sparse matrix
"""
dd = self._def_dim
if len(key) > 2:
# This may be a specification of supercell indices
if isinstance(key[-1], tuple):
# We guess it is the supercell index
off = self.geometry.sc_index(key[-1]) * self.no
key = [el for el in key[:-1]]
key[1] = self.geometry.osc2uc(key[1]) + off
key = tuple(
self.geometry._sanitize_orbs(k) if i < 2 else k for i, k in enumerate(key)
)
dd = self._def_dim
if dd >= 0:
key = tuple(key) + (dd,)
self._def_dim = -1
Expand Down
13 changes: 13 additions & 0 deletions src/sisl/tests/test_sparse_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,12 @@ def test_numpy_reduction(self, setup):
S.finalize()
assert np.sum(S, axis=(0, 1)) == pytest.approx(1 * 2 * 2)

def test_sanitize_atoms_assign(self, setup):
g = graphene(atoms=Atom(6, R=1.43))
S = SparseAtom(g)
for i in range(2):
S[i, 1:4] = 1

def test_fromsp1(self, setup):
g = setup.g.repeat(2, 0).tile(2, 1)
lil = sc.sparse.lil_matrix((g.na, g.na_s), dtype=np.int32)
Expand Down Expand Up @@ -769,3 +775,10 @@ def test_translate_sparse_atoms():
assert transl_both[1, 1] == 2
assert transl_both[0, 1] == 0
assert transl_both[0, 3] == 3


def test_sanitize_orbs_assign():
g = graphene(atoms=Atom(6, R=[1.43, 1.66]))
S = SparseOrbital(g)
for i in range(2):
S[i, 1:4] = 1