Skip to content

Commit

Permalink
Save objects imported with an alias and top level modules by referenc…
Browse files Browse the repository at this point in the history
…e in `dump_session(byref=TRUE)`

Currently, `dump_session(byref=True)` misses some imported objects. For example:
- If the session had a statement `import numpy as np`, it may find a reference to the `numpy` named
as `np` in some internal module listed in `sys.resources`.  But if the module was imported with a
non-canonical name, like `import numpy as nump`, it won't find it at all. Mapping the objects by id
in `modmap` solves the issue.  Note that just types of objects usually imported under an alias must
be looked up by id, otherwise common objects like singletons may be wrongly attributed to a module,
and such reference in the module could change to a different object depending on its initialization
and state.
- If a object in the global scope is a top level module, like `math`, again `save_session` may find
a reference to it in another module and it works. But if this module isn't referenced anywhere else,
it won't be found because the function only looks for objects inside the `sys.resources` modules and
not for the modules themselves.

This commit introduces two new attributes to session modules saved by reference:
- `__dill_imported_as`: a list with (module name, object name, object alias in session)
- `__dill_imported_top_level`: a list with (module name, module alias in session)

I did it this way for forwards (complete) and backwards (partial) compatibility.

Oh, and I got rid of that nasty `exec()` call in `_restore_modules()`! ;)
  • Loading branch information
leogama committed Apr 25, 2022
1 parent 44a9e54 commit d430c2a
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,50 +397,75 @@ def loads(str, ignore=None, **kwds):
### End: Shorthands ###

### Pickle the Interpreter Session
SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception,
FunctionType, MethodType, BuiltinMethodType)

def _module_map():
"""get map of imported modules"""
from collections import defaultdict
modmap = defaultdict(list)
from collections import defaultdict, namedtuple
modmap = namedtuple('Modmap', ['by_module', 'by_name', 'top_level'])
modmap = modmap(defaultdict(list), defaultdict(list), {})
items = 'items' if PY3 else 'iteritems'
for name, module in getattr(sys.modules, items)():
for modname, module in getattr(sys.modules, items)():
if module is None:
continue
for objname, obj in module.__dict__.items():
modmap[objname].append((obj, name))
return modmap

def _lookup_module(modmap, name, obj, main_module): #FIXME: needs work
"""lookup name if module is imported"""
for modobj, modname in modmap[name]:
if '.' not in modname:
modmap.top_level[id(module)] = modname
for objname, modobj in module.__dict__.items():
modmap.by_name[objname].append((modobj, modname))
modmap.by_id[id(modobj)].append((modobj, objname, modname))
return modmap, modmap_byid

def _lookup_module(modmap, name, obj, main_module):
"""lookup name or id of obj if module is imported"""
for modobj, modname in modmap.by_name[name]:
if modobj is obj and sys.modules[modname] is not main_module:
return modname
return modname, name
if isinstance(obj, SESSION_IMPORTED_AS_TYPES):
for modobj, objname, modname in modmap.by_id[id(obj)]:
if sys.modules[modname] is not main_module:
return modname, objname
return None, None

def _stash_modules(main_module):
modmap = _module_map()
imported = []
imported_as = []
imported_top_level = [] # backwards compatibility
original = {}
items = 'items' if PY3 else 'iteritems'
for name, obj in getattr(main_module.__dict__, items)():
source_module = _lookup_module(modmap, name, obj, main_module)
source_module, objname = _lookup_module(modmap, name, obj, main_module)
if source_module:
imported.append((source_module, name))
if objname == name:
imported.append((source_module, name))
else:
imported_as.append((source_module, objname, name))
else:
original[name] = obj
try:
imported_top_level.append((modmap.top_level[id(obj)], name))
except KeyError:
original[name] = obj
if len(imported):
import types
newmod = types.ModuleType(main_module.__name__)
newmod = ModuleType(main_module.__name__)
newmod.__dict__.update(original)
newmod.__dill_imported = imported
newmod.__dill_imported_as = imported_as
newmod.__dill_imported_top_level = imported_top_level
return newmod
else:
return main_module

def _restore_modules(main_module):
if '__dill_imported' not in main_module.__dict__:
return
imports = main_module.__dict__.pop('__dill_imported')
for module, name in imports:
exec("from %s import %s" % (module, name), main_module.__dict__)
try:
for modname, name in main_module.__dict__.pop('__dill_imported'):
main_module.__dict__[name] = __import__(modname, None, None, [name]).__dict__[name]
for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'):
main_module.__dict__[name] = __import__(modname, None, None, [objname]).__dict__[objname]
for modname, name in main_module.__dict__.pop('__dill_imported_top_level'):
main_module.__dict__[name] = __import__(modname)
except KeyError:
pass

#NOTE: 06/03/15 renamed main_module to main
def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds):
Expand Down

0 comments on commit d430c2a

Please sign in to comment.