diff --git a/ChangeLog b/ChangeLog index 61773a4576..19e516f7ea 100644 --- a/ChangeLog +++ b/ChangeLog @@ -13,6 +13,8 @@ What's New in astroid 2.7.2? Release date: TBA * ``BaseContainer`` is now public, and will replace ``_BaseContainer`` completely in astroid 3.0. +* The call cache used by inference functions produced by ``inference_tip`` + can now be cleared via ``clear_inference_tip_cache``. What's New in astroid 2.7.1? diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index d97a9fc9cd..2a7adcd6f9 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -3,36 +3,35 @@ """Transform utilities (filters and decorator)""" -import itertools import typing import wrapt -# pylint: disable=dangerous-default-value from astroid.exceptions import InferenceOverwriteError from astroid.nodes import NodeNG +InferFn = typing.Callable[..., typing.Any] + +_cache: typing.Dict[typing.Tuple[InferFn, NodeNG], typing.Any] = {} + + +def clear_inference_tip_cache(): + """Clear the inference tips cache.""" + _cache.clear() + @wrapt.decorator -def _inference_tip_cached(func, instance, args, kwargs, _cache={}): # noqa:B006 +def _inference_tip_cached(func, instance, args, kwargs): """Cache decorator used for inference tips""" node = args[0] try: - return iter(_cache[func, node]) + result = _cache[func, node] except KeyError: - result = func(*args, **kwargs) - # Need to keep an iterator around - original, copy = itertools.tee(result) - _cache[func, node] = list(copy) - return original - - -# pylint: enable=dangerous-default-value + result = _cache[func, node] = list(func(*args, **kwargs)) + return iter(result) -def inference_tip( - infer_function: typing.Callable, raise_on_overwrite: bool = False -) -> typing.Callable: +def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn: """Given an instance specific inference function, return a function to be given to AstroidManager().register_transform to set this inference function. @@ -54,9 +53,7 @@ def inference_tip( excess overwrites. """ - def transform( - node: NodeNG, infer_function: typing.Callable = infer_function - ) -> NodeNG: + def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG: if ( raise_on_overwrite and node._explicit_inference is not None