Skip to content

Commit

Permalink
Fix supported_devices_and_dtypes (#6081)
Browse files Browse the repository at this point in the history
remove hardcoded values with proper backend values.
  • Loading branch information
CatB1t authored Oct 24, 2022
1 parent e1d1105 commit 0874c7d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
24 changes: 9 additions & 15 deletions ivy/functional/ivy/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,7 @@ def num_ivy_arrays_on_dev(device: Union[ivy.Device, ivy.NativeDevice], /) -> int
@handle_nestable
@handle_exceptions
def print_all_ivy_arrays_on_dev(
*,
device: Union[ivy.Device, ivy.NativeDevice] = None,
attr_only: bool = True
*, device: Union[ivy.Device, ivy.NativeDevice] = None, attr_only: bool = True
) -> None:
"""
Prints the shape and dtype for all ivy arrays which are currently alive on the
Expand Down Expand Up @@ -1115,30 +1113,26 @@ def _is_valid_devices_attributes(fn: Callable) -> bool:


def _get_devices(fn, complement=True):
# TODO: Not hardcode this
VALID_DEVICES = ("cpu",)
INVALID_DEVICES = (
"gpu",
"tpu",
)
ALL_DEVICES = VALID_DEVICES + INVALID_DEVICES
valid_devices = ivy.valid_devices
invalid_devices = ivy.invalid_devices
all_devices = ivy.all_devices

supported = set(VALID_DEVICES)
supported = set(ivy.valid_devices)

is_backend_fn = "backend" in fn.__module__
is_frontend_fn = "frontend" in fn.__module__
is_einops_fn = "einops" in fn.__name__
if not is_backend_fn and not is_frontend_fn and not is_einops_fn:
if complement:
supported = set(ALL_DEVICES).difference(supported)
supported = set(all_devices).difference(supported)
return supported

# Their values are formated like either
# 1. fn.supported_devices = ("cpu",)
# Could also have the "all" value for the framework
basic = [
("supported_devices", set.intersection, VALID_DEVICES),
("unsupported_devices", set.difference, INVALID_DEVICES),
("supported_devices", set.intersection, valid_devices),
("unsupported_devices", set.difference, invalid_devices),
]
for (key, merge_fn, base) in basic:
if hasattr(fn, key):
Expand All @@ -1149,7 +1143,7 @@ def _get_devices(fn, complement=True):
supported = merge_fn(supported, set(v))

if complement:
supported = set(ALL_DEVICES).difference(supported)
supported = set(all_devices).difference(supported)

return tuple(supported)

Expand Down
8 changes: 1 addition & 7 deletions ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -3200,14 +3200,8 @@ def _is_valid_device_and_dtypes_attributes(fn: Callable) -> bool:


def _all_dnd_combinations():
# TODO: not hard code this

VALID_DEVICES = ("cpu",)
INVALID_DEVICES = ("gpu", "tpu")
ALL_DEVICES = VALID_DEVICES + INVALID_DEVICES

all_comb = {}
for device in ALL_DEVICES:
for device in ivy.all_devices:
all_comb[device] = ivy.all_dtypes
return all_comb

Expand Down

0 comments on commit 0874c7d

Please sign in to comment.