Skip to content

Commit

Permalink
Add caching LCA class
Browse files Browse the repository at this point in the history
  • Loading branch information
cmutel committed Oct 2, 2024
1 parent c33a639 commit 116cbb8
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 2 deletions.
8 changes: 6 additions & 2 deletions bw2calc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa
__all__ = [
"CachingLCA",
"DenseLCA",
"LCA",
"LeastSquaresLCA",
Expand Down Expand Up @@ -31,6 +32,8 @@
Installing it could give you much faster calculations.
"""

PYPARDISO, UMFPACK = False, False

try:
from pypardiso import factorized, spsolve

Expand All @@ -41,14 +44,14 @@
if pltf in ARM:
try:
import scikits.umfpack

UMFPACK = True
except ImportError:
warnings.warn(UMFPACK_WARNING)
elif pltf in AMD_INTEL:
warnings.warn(PYPARDISO_WARNING)

from scipy.sparse.linalg import factorized, spsolve

PYPARDISO = False
try:
from presamples import PackagesDataLoader
except ImportError:
Expand All @@ -63,6 +66,7 @@
prepare_lca_inputs = get_activity = None


from .caching_lca import CachingLCA
from .dense_lca import DenseLCA
from .iterative_lca import IterativeLCA
from .lca import LCA
Expand Down
41 changes: 41 additions & 0 deletions bw2calc/caching_lca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from scipy import sparse

from .lca import LCA
from .result_cache import ResultCache


class CachingLCA(LCA):
"""Custom class which caches supply vectors.
Cache resets upon iteration. If you do weird stuff outside of iteration you should probably
use the regular LCA class."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cache = ResultCache()

def __next__(self) -> None:
self.cache.reset()
super().__next__(self)

def lci_calculation(self) -> None:
"""The actual LCI calculation.
Separated from ``lci`` to be reusable in cases where the matrices are already built, e.g.
``redo_lci`` and Monte Carlo classes.
"""
if hasattr(self, "cache") and len(self.demand) == 1:
key, value = list(self.demand.items())[0]
try:
self.supply_array = self.cache[key] * value
except KeyError:
self.supply_array = self.solve_linear_system()
self.cache.add(key, self.supply_array.reshape((-1, 1)) / value)
else:
self.supply_array = self.solve_linear_system()
# Turn 1-d array into diagonal matrix
count = len(self.dicts.activity)
self.inventory = self.biosphere_matrix @ sparse.spdiags(
[self.supply_array], [0], count, count
)
62 changes: 62 additions & 0 deletions bw2calc/result_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import math
from collections.abc import Mapping
from typing import List

import numpy as np


class ResultCache(Mapping):
def __init__(self, block_size: int = 100):
"""This class allows supply vector results to be cached."""
self.next_index = 0
self.block_size = block_size
self.indices = dict()

def __getitem__(self, key: int) -> np.ndarray:
if not hasattr(self, "array"):
raise KeyError
return self.array[:, self.indices[key]]

def __len__(self) -> int:
return len(self.indices)

def __iter__(self):
return iter(self.indices)

def __contains__(self, key: int) -> bool:
return key in self.indices

def add(self, indices: List[int], array: np.ndarray) -> None:
if not hasattr(self, "array"):
self.array = np.empty((array.shape[0], self.block_size), dtype=np.float32)

if array.shape[0] != self.array.shape[0]:
raise ValueError(
f"Wrong number of rows in array ({array.shape[0]} should be {self.array.shape[0]})"
)
if len(array.shape) != 2:
raise ValueError(
f"`array` must be a numpy array with two dimensions (got {len(array.shape)})"
)
if len(indices) != array.shape[1]:
raise ValueError(
f"`indices` has different length than `array` ({len(indices)} vs. {array.shape[1]})"
)

if (total_columns := self.next_index + array.shape[1]) > self.array.shape[1]:
extra_blocks = math.ceil((total_columns - self.array.shape[1]) / self.block_size)
self.array = np.hstack(
(self.array, np.empty((self.array.shape[0], self.block_size * extra_blocks)))
)

# Would be faster with numpy bool arrays
for enum_index, data_obj_index in enumerate(indices):
if data_obj_index not in self.indices:
self.indices[data_obj_index] = self.next_index
self.array[:, self.next_index] = array[:, enum_index]
self.next_index += 1

def reset(self) -> None:
self.indices = dict()
self.next_index = 0
delattr(self, "array")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,4 @@ include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
skip = ["bw2calc/__init__.py"]
103 changes: 103 additions & 0 deletions tests/test_result_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import numpy as np
import pytest

from bw2calc.result_cache import ResultCache


def test_first_use():
rc = ResultCache()
assert not hasattr(rc, "array")
rc.add([5], np.arange(5).reshape((-1, 1)))

assert rc.array.shape == (5, 100)
assert np.allclose(rc.array[:, 0], np.arange(5))
assert rc.indices[5] == 0


def test_missing():
rc = ResultCache()
rc.add([5], np.arange(5).reshape((-1, 1)))

with pytest.raises(KeyError):
rc[10]


def test_missing_before_first_use():
rc = ResultCache()

with pytest.raises(KeyError):
rc[10]


def test_getitem():
rc = ResultCache()
rc.add([5], np.arange(5).reshape((-1, 1)))

assert np.allclose(rc[5], np.arange(5))


def test_contains():
rc = ResultCache()
rc.add([5], np.arange(5).reshape((-1, 1)))

assert 5 in rc


def test_add_errors():
rc = ResultCache()
rc.add([5], np.arange(5).reshape((-1, 1)))

with pytest.raises(ValueError):
rc.add([5], np.arange(10).reshape((-1, 1)))
with pytest.raises(ValueError):
rc.add([5], np.arange(5).reshape((-1, 1, 1)))
with pytest.raises(ValueError):
rc.add([5, 2], np.arange(5).reshape((-1, 1)))


def test_add_2d():
rc = ResultCache()
rc.add([5], np.arange(5).reshape((-1, 1)))
rc.add([7, 10], np.arange(5, 15).reshape((5, 2)))

assert rc.array.shape == (5, 100)
assert np.allclose(rc.array[:, 1], [5, 7, 9, 11, 13])
assert np.allclose(rc.array[:, 2], [6, 8, 10, 12, 14])
assert rc.indices[5] == 0
assert rc.indices[7] == 1
assert rc.indices[10] == 2
assert np.allclose(rc[7], [5, 7, 9, 11, 13])
assert np.allclose(rc[10], [6, 8, 10, 12, 14])


def test_dont_overwrite_existing():
rc = ResultCache()
rc.add([5], np.arange(5).reshape((-1, 1)))
rc.add([5, 10], np.arange(5, 15).reshape((5, 2)))

assert rc.array.shape == (5, 100)
assert np.allclose(rc.array[:, 0], np.arange(5))
assert np.allclose(rc.array[:, 1], [6, 8, 10, 12, 14])
assert rc.indices[5] == 0
assert rc.indices[10] == 1
assert np.allclose(rc[5], np.arange(5))
assert np.allclose(rc[10], [6, 8, 10, 12, 14])


def test_expand():
rc = ResultCache(10)
rc.add(list(range(25)), np.arange(100).reshape((4, -1)))

assert rc.array.shape == (4, 30)
assert np.allclose(rc.array[0, :25], range(25))
assert np.allclose(rc.array[:, 0], [0, 25, 50, 75])


def test_reset():
rc = ResultCache(10)
rc.add(list(range(25)), np.arange(100).reshape((4, -1)))
rc.reset()

assert not hasattr(rc, "array")
assert rc.indices == {}
assert rc.next_index == 0

0 comments on commit 116cbb8

Please sign in to comment.