From 67f0b780db61930d2840f5135cd46c636ff9d94a Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Mon, 25 Apr 2022 15:37:39 -0300 Subject: [PATCH] Save objects imported with an alias and top level modules by reference in `dump_session(byref=TRUE)` Currently, `dump_session(byref=True)` misses some imported objects. For example: - If the session had a statement `import numpy as np`, it may find a reference to the `numpy` named as `np` in some internal module listed in `sys.resources`. But if the module was imported with a non-canonical name, like `import numpy as nump`, it won't find it at all. Mapping the objects by id in `modmap` solves the issue. Note that just types of objects usually imported under an alias must be looked up by id, otherwise common objects like singletons may be wrongly attributed to a module, and such reference in the module could change to a different object depending on its initialization and state. - If a object in the global scope is a top level module, like `math`, again `save_session` may find a reference to it in another module and it works. But if this module isn't referenced anywhere else, it won't be found because the function only looks for objects inside the `sys.resources` modules and not for the modules themselves. This commit introduces two new attributes to session modules saved by reference: - `__dill_imported_as`: a list with (module name, object name, object alias in session) - `__dill_imported_top_level`: a list with (module name, module alias in session) I did it this way for forwards (complete) and backwards (partial) compatibility. Oh, and I got rid of that nasty `exec()` call in `_restore_modules()`! ;) --- dill/_dill.py | 67 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/dill/_dill.py b/dill/_dill.py index 13b2b246..eccfebba 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -397,50 +397,75 @@ def loads(str, ignore=None, **kwds): ### End: Shorthands ### ### Pickle the Interpreter Session +SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, + FunctionType, MethodType, BuiltinMethodType) + def _module_map(): """get map of imported modules""" - from collections import defaultdict - modmap = defaultdict(list) + from collections import defaultdict, namedtuple + modmap = namedtuple('Modmap', ['by_module', 'by_name', 'top_level']) + modmap = modmap(defaultdict(list), defaultdict(list), {}) items = 'items' if PY3 else 'iteritems' - for name, module in getattr(sys.modules, items)(): + for modname, module in getattr(sys.modules, items)(): if module is None: continue - for objname, obj in module.__dict__.items(): - modmap[objname].append((obj, name)) - return modmap - -def _lookup_module(modmap, name, obj, main_module): #FIXME: needs work - """lookup name if module is imported""" - for modobj, modname in modmap[name]: + if '.' not in modname: + modmap.top_level[id(module)] = modname + for objname, modobj in module.__dict__.items(): + modmap.by_name[objname].append((modobj, modname)) + modmap.by_id[id(modobj)].append((modobj, objname, modname)) + return modmap, modmap_byid + +def _lookup_module(modmap, name, obj, main_module): + """lookup name or id of obj if module is imported""" + for modobj, modname in modmap.by_name[name]: if modobj is obj and sys.modules[modname] is not main_module: - return modname + return modname, name + if isinstance(obj, SESSION_IMPORTED_AS_TYPES): + for modobj, objname, modname in modmap.by_id[id(obj)]: + if sys.modules[modname] is not main_module: + return modname, objname + return None, None def _stash_modules(main_module): modmap = _module_map() imported = [] + imported_as = [] + imported_top_level = [] # backwards compatibility original = {} items = 'items' if PY3 else 'iteritems' for name, obj in getattr(main_module.__dict__, items)(): - source_module = _lookup_module(modmap, name, obj, main_module) + source_module, objname = _lookup_module(modmap, name, obj, main_module) if source_module: - imported.append((source_module, name)) + if objname == name: + imported.append((source_module, name)) + else: + imported_as.append((source_module, objname, name)) else: - original[name] = obj + try: + imported_top_level.append((modmap.top_level[id(obj)], name)) + except KeyError: + original[name] = obj if len(imported): - import types - newmod = types.ModuleType(main_module.__name__) + newmod = ModuleType(main_module.__name__) newmod.__dict__.update(original) newmod.__dill_imported = imported + newmod.__dill_imported_as = imported_as + newmod.__dill_imported_top_level = imported_top_level return newmod else: return main_module def _restore_modules(main_module): - if '__dill_imported' not in main_module.__dict__: - return - imports = main_module.__dict__.pop('__dill_imported') - for module, name in imports: - exec("from %s import %s" % (module, name), main_module.__dict__) + try: + for modname, name in main_module.__dict__.pop('__dill_imported'): + main_module.__dict__[name] = __import__(modname, None, None, [name]).__dict__[name] + for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'): + main_module.__dict__[name] = __import__(modname, None, None, [objname]).__dict__[objname] + for modname, name in main_module.__dict__.pop('__dill_imported_top_level'): + main_module.__dict__[name] = __import__(modname) + except KeyError: + pass #NOTE: 06/03/15 renamed main_module to main def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds):