-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
213 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from scipy import sparse | ||
|
||
from .lca import LCA | ||
from .result_cache import ResultCache | ||
|
||
|
||
class CachingLCA(LCA): | ||
"""Custom class which caches supply vectors. | ||
Cache resets upon iteration. If you do weird stuff outside of iteration you should probably | ||
use the regular LCA class.""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.cache = ResultCache() | ||
|
||
def __next__(self) -> None: | ||
self.cache.reset() | ||
super().__next__(self) | ||
|
||
def lci_calculation(self) -> None: | ||
"""The actual LCI calculation. | ||
Separated from ``lci`` to be reusable in cases where the matrices are already built, e.g. | ||
``redo_lci`` and Monte Carlo classes. | ||
""" | ||
if hasattr(self, "cache") and len(self.demand) == 1: | ||
key, value = list(self.demand.items())[0] | ||
try: | ||
self.supply_array = self.cache[key] * value | ||
except KeyError: | ||
self.supply_array = self.solve_linear_system() | ||
self.cache.add(key, self.supply_array.reshape((-1, 1)) / value) | ||
else: | ||
self.supply_array = self.solve_linear_system() | ||
# Turn 1-d array into diagonal matrix | ||
count = len(self.dicts.activity) | ||
self.inventory = self.biosphere_matrix @ sparse.spdiags( | ||
[self.supply_array], [0], count, count | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import math | ||
from collections.abc import Mapping | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
|
||
class ResultCache(Mapping): | ||
def __init__(self, block_size: int = 100): | ||
"""This class allows supply vector results to be cached.""" | ||
self.next_index = 0 | ||
self.block_size = block_size | ||
self.indices = dict() | ||
|
||
def __getitem__(self, key: int) -> np.ndarray: | ||
if not hasattr(self, "array"): | ||
raise KeyError | ||
return self.array[:, self.indices[key]] | ||
|
||
def __len__(self) -> int: | ||
return len(self.indices) | ||
|
||
def __iter__(self): | ||
return iter(self.indices) | ||
|
||
def __contains__(self, key: int) -> bool: | ||
return key in self.indices | ||
|
||
def add(self, indices: List[int], array: np.ndarray) -> None: | ||
if not hasattr(self, "array"): | ||
self.array = np.empty((array.shape[0], self.block_size), dtype=np.float32) | ||
|
||
if array.shape[0] != self.array.shape[0]: | ||
raise ValueError( | ||
f"Wrong number of rows in array ({array.shape[0]} should be {self.array.shape[0]})" | ||
) | ||
if len(array.shape) != 2: | ||
raise ValueError( | ||
f"`array` must be a numpy array with two dimensions (got {len(array.shape)})" | ||
) | ||
if len(indices) != array.shape[1]: | ||
raise ValueError( | ||
f"`indices` has different length than `array` ({len(indices)} vs. {array.shape[1]})" | ||
) | ||
|
||
if (total_columns := self.next_index + array.shape[1]) > self.array.shape[1]: | ||
extra_blocks = math.ceil((total_columns - self.array.shape[1]) / self.block_size) | ||
self.array = np.hstack( | ||
(self.array, np.empty((self.array.shape[0], self.block_size * extra_blocks))) | ||
) | ||
|
||
# Would be faster with numpy bool arrays | ||
for enum_index, data_obj_index in enumerate(indices): | ||
if data_obj_index not in self.indices: | ||
self.indices[data_obj_index] = self.next_index | ||
self.array[:, self.next_index] = array[:, enum_index] | ||
self.next_index += 1 | ||
|
||
def reset(self) -> None: | ||
self.indices = dict() | ||
self.next_index = 0 | ||
delattr(self, "array") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from bw2calc.result_cache import ResultCache | ||
|
||
|
||
def test_first_use(): | ||
rc = ResultCache() | ||
assert not hasattr(rc, "array") | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
|
||
assert rc.array.shape == (5, 100) | ||
assert np.allclose(rc.array[:, 0], np.arange(5)) | ||
assert rc.indices[5] == 0 | ||
|
||
|
||
def test_missing(): | ||
rc = ResultCache() | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
|
||
with pytest.raises(KeyError): | ||
rc[10] | ||
|
||
|
||
def test_missing_before_first_use(): | ||
rc = ResultCache() | ||
|
||
with pytest.raises(KeyError): | ||
rc[10] | ||
|
||
|
||
def test_getitem(): | ||
rc = ResultCache() | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
|
||
assert np.allclose(rc[5], np.arange(5)) | ||
|
||
|
||
def test_contains(): | ||
rc = ResultCache() | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
|
||
assert 5 in rc | ||
|
||
|
||
def test_add_errors(): | ||
rc = ResultCache() | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
|
||
with pytest.raises(ValueError): | ||
rc.add([5], np.arange(10).reshape((-1, 1))) | ||
with pytest.raises(ValueError): | ||
rc.add([5], np.arange(5).reshape((-1, 1, 1))) | ||
with pytest.raises(ValueError): | ||
rc.add([5, 2], np.arange(5).reshape((-1, 1))) | ||
|
||
|
||
def test_add_2d(): | ||
rc = ResultCache() | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
rc.add([7, 10], np.arange(5, 15).reshape((5, 2))) | ||
|
||
assert rc.array.shape == (5, 100) | ||
assert np.allclose(rc.array[:, 1], [5, 7, 9, 11, 13]) | ||
assert np.allclose(rc.array[:, 2], [6, 8, 10, 12, 14]) | ||
assert rc.indices[5] == 0 | ||
assert rc.indices[7] == 1 | ||
assert rc.indices[10] == 2 | ||
assert np.allclose(rc[7], [5, 7, 9, 11, 13]) | ||
assert np.allclose(rc[10], [6, 8, 10, 12, 14]) | ||
|
||
|
||
def test_dont_overwrite_existing(): | ||
rc = ResultCache() | ||
rc.add([5], np.arange(5).reshape((-1, 1))) | ||
rc.add([5, 10], np.arange(5, 15).reshape((5, 2))) | ||
|
||
assert rc.array.shape == (5, 100) | ||
assert np.allclose(rc.array[:, 0], np.arange(5)) | ||
assert np.allclose(rc.array[:, 1], [6, 8, 10, 12, 14]) | ||
assert rc.indices[5] == 0 | ||
assert rc.indices[10] == 1 | ||
assert np.allclose(rc[5], np.arange(5)) | ||
assert np.allclose(rc[10], [6, 8, 10, 12, 14]) | ||
|
||
|
||
def test_expand(): | ||
rc = ResultCache(10) | ||
rc.add(list(range(25)), np.arange(100).reshape((4, -1))) | ||
|
||
assert rc.array.shape == (4, 30) | ||
assert np.allclose(rc.array[0, :25], range(25)) | ||
assert np.allclose(rc.array[:, 0], [0, 25, 50, 75]) | ||
|
||
|
||
def test_reset(): | ||
rc = ResultCache(10) | ||
rc.add(list(range(25)), np.arange(100).reshape((4, -1))) | ||
rc.reset() | ||
|
||
assert not hasattr(rc, "array") | ||
assert rc.indices == {} | ||
assert rc.next_index == 0 |