diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index 15d1e2bb..fe1b4027 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -815,7 +815,6 @@ def kernel__setup(self, prg): warn_about_arg_count_bug=None, work_around_arg_count_bug=None, devs=self.context.devices) - self._wg_info_cache = {} return self def kernel_set_arg_types(self, arg_types): @@ -856,14 +855,19 @@ def kernel_set_arg_types(self, arg_types): devs=self.context.devices) def kernel_get_work_group_info(self, param, device): + try: + wg_info_cache = self._wg_info_cache + except AttributeError: + wg_info_cache = self._wg_info_cache = {} + cache_key = (param, device.int_ptr) try: - return self._wg_info_cache[cache_key] + return wg_info_cache[cache_key] except KeyError: pass result = kernel_old_get_work_group_info(self, param, device) - self._wg_info_cache[cache_key] = result + wg_info_cache[cache_key] = result return result def kernel_set_args(self, *args, **kwargs):