Skip to content

Commit

Permalink
Remove theanof.set_theano_conf and instead use the config context (#4329
Browse files Browse the repository at this point in the history
)

* Remove theanof.set_theano_conf and instead use the config context properly
  • Loading branch information
michaelosthege authored Dec 13, 2020
1 parent 6f15cbb commit 70fdcf9
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 65 deletions.
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Maintenance
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).


## PyMC3 3.10.0 (7 December 2020)

Expand Down
17 changes: 6 additions & 11 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@
from pymc3.exceptions import ImputationWarning
from pymc3.math import flatten_list
from pymc3.memoize import WithMemoization, memoize
from pymc3.theanof import (
floatX,
generator,
gradient,
hessian,
inputvars,
set_theano_conf,
)
from pymc3.theanof import floatX, generator, gradient, hessian, inputvars
from pymc3.util import get_transformed_name, get_var_name
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter

Expand Down Expand Up @@ -288,15 +281,17 @@ def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument
def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
# self._theano_config is set in Model.__new__
self._config_context = None
if hasattr(self, "_theano_config"):
self._old_theano_config = set_theano_conf(self._theano_config)
self._config_context = theano.change_flags(**self._theano_config)
self._config_context.__enter__()
return self

def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument
self.__class__.context_class.get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, "_old_theano_config"):
set_theano_conf(self._old_theano_config)
if self._config_context:
self._config_context.__exit__(typ, value, traceback)

dct[__enter__.__name__] = __enter__
dct[__exit__.__name__] = __exit__
Expand Down
25 changes: 1 addition & 24 deletions pymc3/tests/test_theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from itertools import product

import numpy as np
import pytest
import theano
import theano.tensor as tt

from pymc3.theanof import _conversion_map, set_theano_conf, take_along_axis
from pymc3.theanof import _conversion_map, take_along_axis
from pymc3.vartypes import int_types

FLOATX = str(theano.config.floatX)
Expand Down Expand Up @@ -72,27 +70,6 @@ def np_take_along_axis(arr, indices, axis):
return arr[_make_along_axis_idx(arr.shape, indices, _axis)]


class TestSetTheanoConfig:
def test_invalid_key(self):
with pytest.raises(ValueError) as e:
set_theano_conf({"bad_key": True})
e.match("Unknown")

def test_restore_when_bad_key(self):
with theano.configparser.change_flags(compute_test_value="off"):
with pytest.raises(ValueError):
conf = collections.OrderedDict([("compute_test_value", "raise"), ("bad_key", True)])
set_theano_conf(conf)
assert theano.config.compute_test_value == "off"

def test_restore(self):
with theano.configparser.change_flags(compute_test_value="off"):
conf = set_theano_conf({"compute_test_value": "raise"})
assert conf == {"compute_test_value": "off"}
conf = set_theano_conf(conf)
assert conf == {"compute_test_value": "raise"}


class TestTakeAlongAxis:
def setup_class(self):
self.inputs_buffer = dict()
Expand Down
31 changes: 1 addition & 30 deletions pymc3/theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import numpy as np
import theano

from theano import scalar
from theano import change_flags, scalar
from theano import tensor as tt
from theano.configparser import change_flags
from theano.gof import Op
from theano.gof.graph import inputs
from theano.sandbox.rng_mrg import MRG_RandomStreams
Expand Down Expand Up @@ -442,34 +441,6 @@ def floatX_array(x):
return floatX(np.array(x))


def set_theano_conf(values):
"""Change the theano configuration and return old values.
This is similar to `theano.configparser.change_flags`, but it
returns the original values in a pickleable form.
"""
variables = {}
unknown = set(values.keys())
for variable in theano.configparser._config_var_list:
if variable.fullname in values:
variables[variable.fullname] = variable
unknown.remove(variable.fullname)
if len(unknown) > 0:
raise ValueError("Unknown theano config settings: %s" % unknown)

old = {}
for name, variable in variables.items():
old_value = variable.__get__(True, None)
try:
variable.__set__(None, values[name])
except Exception:
for key, old_value in old.items():
variables[key].__set__(None, old_value)
raise
old[name] = old_value
return old


def ix_(*args):
"""
Theano np.ix_ analog
Expand Down

0 comments on commit 70fdcf9

Please sign in to comment.