diff --git a/dill/_dill.py b/dill/_dill.py index 13b2b246..eccfebba 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -397,50 +397,75 @@ def loads(str, ignore=None, **kwds): ### End: Shorthands ### ### Pickle the Interpreter Session +SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, + FunctionType, MethodType, BuiltinMethodType) + def _module_map(): """get map of imported modules""" - from collections import defaultdict - modmap = defaultdict(list) + from collections import defaultdict, namedtuple + modmap = namedtuple('Modmap', ['by_module', 'by_name', 'top_level']) + modmap = modmap(defaultdict(list), defaultdict(list), {}) items = 'items' if PY3 else 'iteritems' - for name, module in getattr(sys.modules, items)(): + for modname, module in getattr(sys.modules, items)(): if module is None: continue - for objname, obj in module.__dict__.items(): - modmap[objname].append((obj, name)) - return modmap - -def _lookup_module(modmap, name, obj, main_module): #FIXME: needs work - """lookup name if module is imported""" - for modobj, modname in modmap[name]: + if '.' not in modname: + modmap.top_level[id(module)] = modname + for objname, modobj in module.__dict__.items(): + modmap.by_name[objname].append((modobj, modname)) + modmap.by_id[id(modobj)].append((modobj, objname, modname)) + return modmap, modmap_byid + +def _lookup_module(modmap, name, obj, main_module): + """lookup name or id of obj if module is imported""" + for modobj, modname in modmap.by_name[name]: if modobj is obj and sys.modules[modname] is not main_module: - return modname + return modname, name + if isinstance(obj, SESSION_IMPORTED_AS_TYPES): + for modobj, objname, modname in modmap.by_id[id(obj)]: + if sys.modules[modname] is not main_module: + return modname, objname + return None, None def _stash_modules(main_module): modmap = _module_map() imported = [] + imported_as = [] + imported_top_level = [] # backwards compatibility original = {} items = 'items' if PY3 else 'iteritems' for name, obj in getattr(main_module.__dict__, items)(): - source_module = _lookup_module(modmap, name, obj, main_module) + source_module, objname = _lookup_module(modmap, name, obj, main_module) if source_module: - imported.append((source_module, name)) + if objname == name: + imported.append((source_module, name)) + else: + imported_as.append((source_module, objname, name)) else: - original[name] = obj + try: + imported_top_level.append((modmap.top_level[id(obj)], name)) + except KeyError: + original[name] = obj if len(imported): - import types - newmod = types.ModuleType(main_module.__name__) + newmod = ModuleType(main_module.__name__) newmod.__dict__.update(original) newmod.__dill_imported = imported + newmod.__dill_imported_as = imported_as + newmod.__dill_imported_top_level = imported_top_level return newmod else: return main_module def _restore_modules(main_module): - if '__dill_imported' not in main_module.__dict__: - return - imports = main_module.__dict__.pop('__dill_imported') - for module, name in imports: - exec("from %s import %s" % (module, name), main_module.__dict__) + try: + for modname, name in main_module.__dict__.pop('__dill_imported'): + main_module.__dict__[name] = __import__(modname, None, None, [name]).__dict__[name] + for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'): + main_module.__dict__[name] = __import__(modname, None, None, [objname]).__dict__[objname] + for modname, name in main_module.__dict__.pop('__dill_imported_top_level'): + main_module.__dict__[name] = __import__(modname) + except KeyError: + pass #NOTE: 06/03/15 renamed main_module to main def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds):