Skip to content

Commit

Permalink
Merge pull request #1596 from priyanshuone6/latexify_notebook
Browse files Browse the repository at this point in the history
Add latexify notebook
  • Loading branch information
valentinsulzer authored Aug 18, 2021
2 parents 049a9e5 + b8ea482 commit 6f4a152
Show file tree
Hide file tree
Showing 24 changed files with 614 additions and 36 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_on_push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
- name: Install Linux system dependencies
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt install gfortran gcc libopenblas-dev graphviz
sudo apt install texlive-full
Expand Down
554 changes: 554 additions & 0 deletions examples/notebooks/models/latexify.ipynb

Large diffs are not rendered by default.

Binary file added examples/notebooks/models/spm_equations.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _sympy_operator(self, left, right):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
child1, child2 = self.children
eq1 = child1.to_equation()
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _sympy_operator(self, *children):
self.concat_latex = tuple(map(sympy.latex, children))

if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
concat_str = r"\\".join(self.concat_latex)
concat_sym = sympy.Symbol(r"\begin{cases}" + concat_str + r"\end{cases}")
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _sympy_operator(self, child):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
eq_list = []
for child in self.children:
Expand Down
8 changes: 6 additions & 2 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def _jac(self, variable):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
return sympy.symbols(self.name)
return sympy.Symbol(self.name)


class Time(IndependentVariable):
Expand Down Expand Up @@ -72,6 +72,10 @@ def _evaluate_for_shape(self):
"""
return 0

def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
return sympy.Symbol("t")


class SpatialVariable(IndependentVariable):
"""
Expand Down
13 changes: 7 additions & 6 deletions pybamm/expression_tree/operations/latexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _get_concat_displays(self, node):
r"\begin{cases}" + r" \\ ".join(concat_geo) + r"\end{cases}"
)
concat_eqn = sympy.Eq(
sympy.symbols(node.print_name),
sympy.Symbol(node.print_name),
sympy.Symbol(concat_sym),
evaluate=False,
)
Expand Down Expand Up @@ -200,8 +200,9 @@ def _get_param_var(self, node):

# Add spaces between words
node_copy_eqn = node_copy.to_equation()
# Typical current [A] --> \text{Typical current [A]}
if re.search(r"(^[0-9a-zA-Z-\s.-\[\]()]*$)", str(node_copy_eqn)):
node_copy_latex = r"\textit{" + str(node_copy_eqn) + "}"
node_copy_latex = r"\text{" + str(node_copy_eqn) + "}"
else:
node_copy_latex = sympy.latex(node_copy_eqn)

Expand Down Expand Up @@ -235,13 +236,13 @@ def latexify(self):
# Add model name to the list
eqn_list.append(
sympy.Symbol(
r"\underline{\textbf{\large{" + self.model.name + " Equations}}}"
r"\large{\underline{\textbf{" + self.model.name + " Equations}}}"
)
)

for eqn_type in ["rhs", "algebraic"]:
for var, eqn in getattr(self.model, eqn_type).items():
var_symbol = sympy.symbols(var.print_name)
var_symbol = sympy.Symbol(var.print_name)

# Add equation name to the list
eqn_list.append(sympy.Symbol(r"\\ \textbf{" + str(var) + "}"))
Expand Down Expand Up @@ -314,7 +315,7 @@ def latexify(self):
# Add voltage expression to the list
if "Terminal voltage [V]" in self.model.variables:
voltage = self.model.variables["Terminal voltage [V]"].to_equation()
voltage_eqn = sympy.Eq(sympy.symbols("V"), voltage, evaluate=False)
voltage_eqn = sympy.Eq(sympy.Symbol("V"), voltage, evaluate=False)
# Add terminal voltage to the list
eqn_list.append(sympy.Symbol(r"\\ \textbf{Terminal voltage [V]}"))
eqn_list.extend([voltage_eqn])
Expand Down Expand Up @@ -368,7 +369,7 @@ def latexify(self):
# When equations are too huge, set output resolution to default
except RuntimeError: # pragma: no cover
warnings.warn(
"RuntimeError: Setting the output resolution to default"
"RuntimeError - Setting the output resolution to default"
)
return sympy.preview(
eqn_new_line,
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def is_constant(self):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
return sympy.Symbol(self.name)

Expand Down Expand Up @@ -237,6 +237,6 @@ def _evaluate_for_shape(self):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
return sympy.Symbol(self.name)
2 changes: 1 addition & 1 deletion pybamm/expression_tree/printing/print_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def prettify_print_name(name):
"""Prettify print_name using regex"""

# Skip prettify_print_name() for cases like `new_copy()`
if "{" in name:
if "{" in name or "\\" in name:
return name

# Return print_name if name exists in the dictionary
Expand Down
6 changes: 5 additions & 1 deletion pybamm/expression_tree/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Scalar class
#
import numpy as np
import sympy

import pybamm

Expand Down Expand Up @@ -70,4 +71,7 @@ def is_constant(self):

def to_equation(self):
"""Returns the value returned by the node when evaluated."""
return self.value
if self.print_name is not None:
return sympy.Symbol(self.print_name)
else:
return self.value
2 changes: 1 addition & 1 deletion pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,4 +980,4 @@ def print_name(self, name):
self._print_name = prettify_print_name(name)

def to_equation(self):
return sympy.symbols(str(self.name))
return sympy.Symbol(str(self.name))
4 changes: 2 additions & 2 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _sympy_operator(self, child):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
eq1 = self.child.to_equation()
return self._sympy_operator(eq1)
Expand Down Expand Up @@ -643,7 +643,7 @@ def _evaluates_on_edges(self, dimension):

def _sympy_operator(self, child):
"""Override :meth:`pybamm.UnaryOperator._sympy_operator`"""
return sympy.Integral(child, sympy.symbols("xn"))
return sympy.Integral(child, sympy.Symbol("xn"))


class BaseIndefiniteIntegral(Integral):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _evaluate_for_shape(self):
def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
if self.print_name is not None:
return sympy.symbols(self.print_name)
return sympy.Symbol(self.print_name)
else:
return self.name

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def test_inner_simplifications(self):
def test_to_equation(self):
# Test print_name
pybamm.Addition.print_name = "test"
self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.symbols("test"))
self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.Symbol("test"))

# Test Power
self.assertEqual(pybamm.Power(7, 2).to_equation(), 49)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_expression_tree/test_concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,12 @@ def test_numpy_concatenation(self):
def test_to_equation(self):
a = pybamm.Symbol("a", domain="test a")
b = pybamm.Symbol("b", domain="test b")
func_symbol = sympy.symbols(r"\begin{cases}a\\b\end{cases}")
func_symbol = sympy.Symbol(r"\begin{cases}a\\b\end{cases}")

# Test print_name
func = pybamm.Concatenation(a, b)
func.print_name = "test"
self.assertEqual(func.to_equation(), sympy.symbols("test"))
self.assertEqual(func.to_equation(), sympy.Symbol("test"))

# Test concat_sym
self.assertEqual(pybamm.Concatenation(a, b).to_equation(), func_symbol)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_to_equation(self):
# Test print_name
func = pybamm.Arcsinh(a)
func.print_name = "test"
self.assertEqual(func.to_equation(), sympy.symbols("test"))
self.assertEqual(func.to_equation(), sympy.Symbol("test"))

# Test Arcsinh
self.assertEqual(pybamm.Arcsinh(a).to_equation(), sympy.asinh(a))
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_expression_tree/test_independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,15 @@ def test_to_equation(self):
# Test print_name
func = pybamm.IndependentVariable("a")
func.print_name = "test"
self.assertEqual(func.to_equation(), sympy.symbols("test"))
self.assertEqual(func.to_equation(), sympy.Symbol("test"))

self.assertEqual(
pybamm.IndependentVariable("a").to_equation(), sympy.symbols("a")
pybamm.IndependentVariable("a").to_equation(), sympy.Symbol("a")
)

# Test time
self.assertEqual(pybamm.t.to_equation(), sympy.Symbol("t"))


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid

import pybamm
from pybamm.expression_tree.operations.latexify import Latexify

model_dfn = pybamm.lithium_ion.DFN()
func_dfn = str(model_dfn.latexify())
Expand All @@ -17,6 +18,9 @@

class TestLatexify(unittest.TestCase):
def test_latexify(self):
# Test docstring
self.assertEqual(pybamm.BaseModel.latexify.__doc__, Latexify.__doc__)

# Test model name
self.assertIn("Single Particle Model with electrolyte Equations", func_spme)

Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_expression_tree/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def test_to_equation(self):

# Test print_name
func.print_name = "test"
self.assertEqual(func.to_equation(), sympy.symbols("test"))
self.assertEqual(func.to_equation(), sympy.Symbol("test"))

# Test name
self.assertEqual(func1.to_equation(), sympy.symbols("test_name"))
self.assertEqual(func1.to_equation(), sympy.Symbol("test_name"))


class TestFunctionParameter(unittest.TestCase):
Expand Down Expand Up @@ -113,11 +113,11 @@ def test_function_parameter_to_equation(self):

# Test print_name
func.print_name = "test"
self.assertEqual(func.to_equation(), sympy.symbols("test"))
self.assertEqual(func.to_equation(), sympy.Symbol("test"))

# Test name
func1.print_name = None
self.assertEqual(func1.to_equation(), sympy.symbols("func"))
self.assertEqual(func1.to_equation(), sympy.Symbol("func"))


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_expression_tree/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,15 @@ def test_scalar_id(self):

def test_to_equation(self):
a = pybamm.Scalar(3)
b = pybamm.Scalar(4)

# Test value
self.assertEqual(str(a.to_equation()), "3.0")

# Test print_name
b.print_name = "test"
self.assertEqual(str(b.to_equation()), "test")


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def test_test_shape(self):
(y1 + y2).test_shape()

def test_to_equation(self):
self.assertEqual(pybamm.Symbol("test").to_equation(), sympy.symbols("test"))
self.assertEqual(pybamm.Symbol("test").to_equation(), sympy.Symbol("test"))


class TestIsZero(unittest.TestCase):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_expression_tree/test_unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ def test_to_equation(self):

# Test print_name
pybamm.Floor.print_name = "test"
self.assertEqual(pybamm.Floor(-2.5).to_equation(), sympy.symbols("test"))
self.assertEqual(pybamm.Floor(-2.5).to_equation(), sympy.Symbol("test"))

# Test Negate
self.assertEqual(pybamm.Negate(4).to_equation(), -4.0)
Expand All @@ -891,13 +891,13 @@ def test_to_equation(self):

# Test BoundaryValue
self.assertEqual(
pybamm.BoundaryValue(one, "right").to_equation(), sympy.symbols("1")
pybamm.BoundaryValue(one, "right").to_equation(), sympy.Symbol("1")
)
self.assertEqual(
pybamm.BoundaryValue(a, "right").to_equation(), sympy.symbols("a^{surf}")
pybamm.BoundaryValue(a, "right").to_equation(), sympy.Symbol("a^{surf}")
)
self.assertEqual(
pybamm.BoundaryValue(b, "positive tab").to_equation(), sympy.symbols(str(b))
pybamm.BoundaryValue(b, "positive tab").to_equation(), sympy.Symbol(str(b))
)
self.assertEqual(
pybamm.BoundaryValue(c, "left").to_equation(),
Expand All @@ -908,7 +908,7 @@ def test_to_equation(self):
xn = pybamm.SpatialVariable("xn", ["negative electrode"])
self.assertEqual(
pybamm.Integral(d, xn).to_equation(),
sympy.Integral("d", sympy.symbols("xn")),
sympy.Integral("d", sympy.Symbol("xn")),
)


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_to_equation(self):
# Test print_name
func = pybamm.Variable("test_string")
func.print_name = "test"
self.assertEqual(func.to_equation(), sympy.symbols("test"))
self.assertEqual(func.to_equation(), sympy.Symbol("test"))

# Test name
self.assertEqual(pybamm.Variable("name").to_equation(), "name")
Expand Down

0 comments on commit 6f4a152

Please sign in to comment.