Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
priyanshuone6 committed Nov 8, 2021
1 parent 53a8e59 commit 58ce6cb
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 87 deletions.
7 changes: 3 additions & 4 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def version(formatted=False):
EvaluatorPython,
)

if have_jax():
from .expression_tree.operations.evaluate_python import EvaluatorJax
from .expression_tree.operations.evaluate_python import JaxCooMatrix
from .expression_tree.operations.evaluate_python import EvaluatorJax
from .expression_tree.operations.evaluate_python import JaxCooMatrix

from .expression_tree.operations.jacobian import Jacobian
from .expression_tree.operations.convert_to_casadi import CasadiConverter
Expand Down Expand Up @@ -223,8 +222,8 @@ def version(formatted=False):
from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes
from .solvers.scipy_solver import ScipySolver

from .solvers.jax_solver import JaxSolver
if have_jax():
from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu
Expand Down
159 changes: 85 additions & 74 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,110 @@
#
# Write a symbol to python
#
import pybamm
import numbers
from collections import OrderedDict

import numpy as np
import scipy.sparse
from collections import OrderedDict
import numbers

import pybamm

if pybamm.have_jax():
import jax
from jax.config import config

config.update("jax_enable_x64", True)

class JaxCooMatrix:
"""
A sparse matrix in COO format, with internal arrays using jax device arrays

This matrix only has two operations supported, a multiply with a scalar, and a
dot product with a dense vector. It can also be converted to a dense 2D jax
device array
class JaxCooMatrix:
"""
A sparse matrix in COO format, with internal arrays using jax device arrays
This matrix only has two operations supported, a multiply with a scalar, and a
dot product with a dense vector. It can also be converted to a dense 2D jax
device array
Parameters
----------
Parameters
----------
row: arraylike
1D array holding row indices of non-zero entries
col: arraylike
1D array holding col indices of non-zero entries
data: arraylike
1D array holding non-zero entries
shape: 2-element tuple (x, y)
where x is the number of rows, and y the number of columns of the matrix
"""
row: arraylike
1D array holding row indices of non-zero entries
col: arraylike
1D array holding col indices of non-zero entries
data: arraylike
1D array holding non-zero entries
shape: 2-element tuple (x, y)
where x is the number of rows, and y the number of columns of the matrix
"""

def __init__(self, row, col, data, shape):
self.row = jax.numpy.array(row)
self.col = jax.numpy.array(col)
self.data = jax.numpy.array(data)
self.shape = shape
self.nnz = len(self.data)

def toarray(self):
"""convert sparse matrix to a dense 2D array"""
result = jax.numpy.zeros(self.shape, dtype=self.data.dtype)
return result.at[self.row, self.col].add(self.data)

def dot_product(self, b):
"""
dot product of matrix with a dense column vector b
Parameters
----------
b: jax device array
must have shape (n, 1)
"""
# assume b is a column vector
result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype)
return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col])

def scalar_multiply(self, b):
"""
multiply of matrix with a scalar b
Parameters
----------
b: Number or 1 element jax device array
scalar value to multiply
"""
# assume b is a scalar or ndarray with 1 element
return JaxCooMatrix(
self.row, self.col, (self.data * b).reshape(-1), self.shape
def __init__(self, row, col, data, shape):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

def multiply(self, b):
"""
general matrix multiply not supported
"""
raise NotImplementedError
self.row = jax.numpy.array(row)
self.col = jax.numpy.array(col)
self.data = jax.numpy.array(data)
self.shape = shape
self.nnz = len(self.data)

def toarray(self):
"""convert sparse matrix to a dense 2D array"""
result = jax.numpy.zeros(self.shape, dtype=self.data.dtype)
return result.at[self.row, self.col].add(self.data)

def __matmul__(self, b):
"""see self.dot_product"""
return self.dot_product(b)
def dot_product(self, b):
"""
dot product of matrix with a dense column vector b
Parameters
----------
b: jax device array
must have shape (n, 1)
"""
# assume b is a column vector
result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype)
return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col])

def create_jax_coo_matrix(value):
def scalar_multiply(self, b):
"""
Creates a JaxCooMatrix from a scipy.sparse matrix
multiply of matrix with a scalar b
Parameters
----------
b: Number or 1 element jax device array
scalar value to multiply
"""
# assume b is a scalar or ndarray with 1 element
return JaxCooMatrix(self.row, self.col, (self.data * b).reshape(-1), self.shape)

value: scipy.sparse matrix
the sparse matrix to be converted
def multiply(self, b):
"""
general matrix multiply not supported
"""
scipy_coo = value.tocoo()
row = jax.numpy.asarray(scipy_coo.row)
col = jax.numpy.asarray(scipy_coo.col)
data = jax.numpy.asarray(scipy_coo.data)
return JaxCooMatrix(row, col, data, value.shape)
raise NotImplementedError

def __matmul__(self, b):
"""see self.dot_product"""
return self.dot_product(b)


def create_jax_coo_matrix(value):
"""
Creates a JaxCooMatrix from a scipy.sparse matrix
Parameters
----------
value: scipy.sparse matrix
the sparse matrix to be converted
"""
scipy_coo = value.tocoo()
row = jax.numpy.asarray(scipy_coo.row)
col = jax.numpy.asarray(scipy_coo.col)
data = jax.numpy.asarray(scipy_coo.data)
return JaxCooMatrix(row, col, data, value.shape)


def id_to_python_variable(symbol_id, constant=False):
Expand Down Expand Up @@ -539,6 +545,11 @@ class EvaluatorJax:
"""

def __init__(self, symbol):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)

# replace numpy function calls to jax numpy calls
Expand Down
15 changes: 11 additions & 4 deletions pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#
# Solver class using Scipy's adaptive time stepper
#
import numpy as onp

import pybamm

import jax
from jax.experimental.ode import odeint
import jax.numpy as jnp
import numpy as onp
if pybamm.have_jax():
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


class JaxSolver(pybamm.BaseSolver):
Expand Down Expand Up @@ -56,6 +58,11 @@ def __init__(
extrap_tol=0,
extra_options=None,
):
if not pybamm.have_jax():
raise ModuleNotFoundError(
"Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)

# note: bdf solver itself calculates consistent initial conditions so can set
# root_method to none, allow user to override this behavior
super().__init__(
Expand Down
14 changes: 9 additions & 5 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
# (see https://github.com/pints-team/pints)
#
import importlib
import numpy as np
import numbers
import os
import timeit
import pathlib
import pickle
import pybamm
import numbers
import subprocess
import sys
import timeit
import warnings
from collections import defaultdict
from platform import system

import numpy as np

import pybamm


def root_dir():
"""return the root directory of the PyBaMM install directory"""
Expand Down Expand Up @@ -347,7 +349,9 @@ def have_jax():

def install_jax():
"""Install jax, jaxlib"""
if system() != "Windows":
if system() == "Windows":
raise NotImplementedError("Jax is not available on Windows")
else:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "jax==0.2.12", "jaxlib==0.1.70"]
)

0 comments on commit 58ce6cb

Please sign in to comment.