From 0874c7d2a6f586b9923757292d1b8418856b785f Mon Sep 17 00:00:00 2001 From: Mostafa Hani <71686115+CatB1t@users.noreply.github.com> Date: Mon, 24 Oct 2022 05:27:41 +0200 Subject: [PATCH] Fix `supported_devices_and_dtypes` (#6081) remove hardcoded values with proper backend values. --- ivy/functional/ivy/device.py | 24 +++++++++--------------- ivy/functional/ivy/general.py | 8 +------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/ivy/functional/ivy/device.py b/ivy/functional/ivy/device.py index 520163107c112..cb3b86efee758 100644 --- a/ivy/functional/ivy/device.py +++ b/ivy/functional/ivy/device.py @@ -290,9 +290,7 @@ def num_ivy_arrays_on_dev(device: Union[ivy.Device, ivy.NativeDevice], /) -> int @handle_nestable @handle_exceptions def print_all_ivy_arrays_on_dev( - *, - device: Union[ivy.Device, ivy.NativeDevice] = None, - attr_only: bool = True + *, device: Union[ivy.Device, ivy.NativeDevice] = None, attr_only: bool = True ) -> None: """ Prints the shape and dtype for all ivy arrays which are currently alive on the @@ -1115,30 +1113,26 @@ def _is_valid_devices_attributes(fn: Callable) -> bool: def _get_devices(fn, complement=True): - # TODO: Not hardcode this - VALID_DEVICES = ("cpu",) - INVALID_DEVICES = ( - "gpu", - "tpu", - ) - ALL_DEVICES = VALID_DEVICES + INVALID_DEVICES + valid_devices = ivy.valid_devices + invalid_devices = ivy.invalid_devices + all_devices = ivy.all_devices - supported = set(VALID_DEVICES) + supported = set(ivy.valid_devices) is_backend_fn = "backend" in fn.__module__ is_frontend_fn = "frontend" in fn.__module__ is_einops_fn = "einops" in fn.__name__ if not is_backend_fn and not is_frontend_fn and not is_einops_fn: if complement: - supported = set(ALL_DEVICES).difference(supported) + supported = set(all_devices).difference(supported) return supported # Their values are formated like either # 1. fn.supported_devices = ("cpu",) # Could also have the "all" value for the framework basic = [ - ("supported_devices", set.intersection, VALID_DEVICES), - ("unsupported_devices", set.difference, INVALID_DEVICES), + ("supported_devices", set.intersection, valid_devices), + ("unsupported_devices", set.difference, invalid_devices), ] for (key, merge_fn, base) in basic: if hasattr(fn, key): @@ -1149,7 +1143,7 @@ def _get_devices(fn, complement=True): supported = merge_fn(supported, set(v)) if complement: - supported = set(ALL_DEVICES).difference(supported) + supported = set(all_devices).difference(supported) return tuple(supported) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 09df3d5f910b1..54292e4cb18f3 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -3200,14 +3200,8 @@ def _is_valid_device_and_dtypes_attributes(fn: Callable) -> bool: def _all_dnd_combinations(): - # TODO: not hard code this - - VALID_DEVICES = ("cpu",) - INVALID_DEVICES = ("gpu", "tpu") - ALL_DEVICES = VALID_DEVICES + INVALID_DEVICES - all_comb = {} - for device in ALL_DEVICES: + for device in ivy.all_devices: all_comb[device] = ivy.all_dtypes return all_comb