Skip to content

Commit

Permalink
Merge pull request #919 from pybamm-team/issue-917-parameter
Browse files Browse the repository at this point in the history
#917 fix function that returns float
  • Loading branch information
valentinsulzer authored Mar 26, 2020
2 parents dca1f30 + 799e09b commit 998b2c6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

## Bug fixes

- Fixed bug raised if function returns a scalar ([#919](https://github.com/pybamm-team/PyBaMM/pull/919))
- Updated Getting started notebook 2 ([#903](https://github.com/pybamm-team/PyBaMM/pull/903))
- Reformatted external circuit submodels ([#879](https://github.com/pybamm-team/PyBaMM/pull/879))
- Some bug fixes to generalize specifying models that aren't battery models, see [#846](https://github.com/pybamm-team/PyBaMM/issues/846)
Expand Down
3 changes: 3 additions & 0 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ def _process_symbol(self, symbol):
# return differentiated function
new_diff_variable = self.process_symbol(symbol.diff_variable)
function_out = function.diff(new_diff_variable)
# Convert possible float output to a pybamm scalar
if isinstance(function_out, numbers.Number):
return pybamm.Scalar(function_out)
# Process again just to be sure
return self.process_symbol(function_out)

Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_parameters/test_parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def test_process_function_parameter(self):
"a": 3,
"func": pybamm.load_function("process_symbol_test_function.py"),
"const": 254,
"float_func": lambda x: 42,
}
)
a = pybamm.InputParameter("a")
Expand All @@ -320,6 +321,11 @@ def test_process_function_parameter(self):
processed_diff_func = parameter_values.process_symbol(diff_func)
self.assertEqual(processed_diff_func.evaluate(u={"a": 3}), 123)

# function parameter that returns a python float
func = pybamm.FunctionParameter("float_func", a)
processed_func = parameter_values.process_symbol(func)
self.assertEqual(processed_func.evaluate(), 42)

# function itself as input (different to the variable being an input)
parameter_values = pybamm.ParameterValues({"func": "[input]"})
a = pybamm.Scalar(3)
Expand Down

0 comments on commit 998b2c6

Please sign in to comment.