Skip to content

Commit

Permalink
Merge pull request #10 from marstaa/implement-slicing
Browse files Browse the repository at this point in the history
Implement slicing
  • Loading branch information
marstaa authored Sep 18, 2021
2 parents b43c227 + c669e2b commit 8615c02
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
18 changes: 18 additions & 0 deletions pyspherex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,21 @@ def __matmul__(self, other):
res += sum(coeff1 * coeff2.conjugate()
for coeff1, coeff2 in zip(self.coeffs[degree], other.coeffs[degree]))
return res

def __len__(self):
if self.coeffs:
return max(self.coeffs.keys()) + 1
return 0

def __getitem__(self, key):
if isinstance(key, list):
coeffs_new = {}
for degree in key:
if degree in self.coeffs:
coeffs_new[degree] = self.coeffs[degree].copy()
return Expansion(coeffs_new)
if isinstance(key, int):
return self[[key]]
if isinstance(key, slice):
return self[list(range(*key.indices(len(self.coeffs))))]
raise TypeError('`key` must be an index, a slice or a list of integers')
36 changes: 36 additions & 0 deletions tests/test_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,39 @@ def test_expansion_matmul():
exp2 = Expansion.from_data(phi, theta, data2, degree_max)
res = -4j * np.pi / 3
assert exp1 @ exp2 == approx(res, rel=1e-2)

def test_expansion_len():
"""Test length of expansion"""
exp = Expansion({})
assert len(exp) == 0

exp = Expansion({0: [1], 1: [2, 3, 4]})
assert len(exp) == 2

def test_expansion_slice():
"""Test slicing of expansion"""
degree_max = 10
coeffs = {degree: [np.random.normal() + 1j * np.random.normal()
for order in range(2 * degree + 1)]
for degree in range(degree_max + 1)}
exp = Expansion(coeffs)

sliced = exp[0]
assert len(sliced) == 1
assert sliced.coeffs[0] == coeffs[0]

sliced = exp[1:5]
assert len(sliced) == 5

sliced = exp[1:-1]
assert len(sliced) == degree_max

sliced = exp[1:]
assert len(sliced) == degree_max + 1

sliced = exp[1:9:2]
assert len(sliced) == 8
assert 1 in sliced.coeffs
assert 3 in sliced.coeffs
assert 5 in sliced.coeffs
assert 7 in sliced.coeffs

0 comments on commit 8615c02

Please sign in to comment.