-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
140 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Any | ||
|
||
from matrix_utils import SparseMatrixDict | ||
from pydantic import BaseModel | ||
|
||
|
||
class RestrictionsValidator(BaseModel): | ||
restrictions: dict[tuple[str, ...], list[tuple[str, ...]]] | ||
|
||
|
||
class RestrictedSparseMatrixDict(SparseMatrixDict): | ||
def __init__(self, restrictions: dict, *args, **kwargs): | ||
"""Like SparseMatrixDict, but follows `restrictions` on what can be multiplied. | ||
Only for use with normalization and weighting.""" | ||
super().__init__(*args, **kwargs) | ||
RestrictionsValidator(restrictions=restrictions) | ||
self._restrictions = restrictions | ||
|
||
def __matmul__(self, other: Any) -> SparseMatrixDict: | ||
"""Define logic for `@` matrix multiplication operator. | ||
Note that the sparse matrix dict must come first, i.e. `self @ other`. | ||
""" | ||
if isinstance(other, (SparseMatrixDict, RestrictedSparseMatrixDict)): | ||
return SparseMatrixDict( | ||
{ | ||
(a, *b): c @ d | ||
for a, c in self.items() | ||
for b, d in other.items() | ||
if b[0] in self._restrictions[a] | ||
} | ||
) | ||
else: | ||
return super().__matmul__(other) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import pytest | ||
from matrix_utils import SparseMatrixDict | ||
from pydantic import ValidationError | ||
|
||
from bw2calc.restricted_sparse_matrix_dict import RestrictedSparseMatrixDict, RestrictionsValidator | ||
|
||
|
||
class Dummy: | ||
def __init__(self, a): | ||
self.a = a | ||
|
||
def __matmul__(self, other): | ||
return self.a + other | ||
|
||
|
||
def test_restricted_sparse_matrix_dict(): | ||
smd = SparseMatrixDict({(("one",), "foo"): 1, (("two",), "bar"): 2}) | ||
rsmd = RestrictedSparseMatrixDict( | ||
{("seven",): [("one",)], ("eight",): [("two",)]}, | ||
{("seven",): Dummy(7), ("eight",): Dummy(8)}, | ||
) | ||
|
||
result = rsmd @ smd | ||
assert isinstance(result, SparseMatrixDict) | ||
assert len(result) == 2 | ||
assert result[(("seven",), ("one",), "foo")] == 8 | ||
assert result[(("eight",), ("two",), "bar")] == 10 | ||
|
||
|
||
def test_restrictions_validator(): | ||
assert RestrictionsValidator(restrictions={("seven",): [("one",)], ("eight",): [("two",)]}) | ||
with pytest.raises(ValidationError): | ||
RestrictionsValidator(restrictions={"seven": [("one",)], ("eight",): [("two",)]}) |