-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: nstarman <[email protected]>
- Loading branch information
Showing
4 changed files
with
111 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Test Array.""" | ||
|
||
import array | ||
import sys | ||
from enum import Enum, auto | ||
|
||
|
||
class ArrayLibraries(Enum): | ||
"""Array libraries.""" | ||
|
||
array = auto() | ||
|
||
numpy = auto() | ||
zarr = auto() | ||
dask = auto() | ||
xarray = auto() | ||
|
||
# ML | ||
jax = auto() | ||
torch = auto() | ||
tensorflow = auto() | ||
|
||
|
||
def get_array_from_library(name: ArrayLibraries) -> tuple[str, object]: # noqa: PLR0911 | ||
"""Import an array library.""" | ||
|
||
if name == ArrayLibraries.array: | ||
vrsn = "{}.{}.{}".format(*sys.version_info[:3]) | ||
|
||
return vrsn, array.array("f", [1.0, 2.0, 3.0]) | ||
|
||
elif name == ArrayLibraries.numpy: | ||
import numpy as np | ||
|
||
return np.__version__, np.linspace(0, 1, 10, dtype=np.float64) | ||
|
||
elif name == ArrayLibraries.jax: | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
return jax.__version__, jnp.linspace(0, 1, 10) | ||
|
||
elif name == ArrayLibraries.torch: | ||
import torch | ||
|
||
return torch.__version__, torch.linspace(0, 1, 10, dtype=torch.float64) | ||
|
||
elif name == ArrayLibraries.zarr: | ||
import zarr | ||
|
||
return zarr.__version__, zarr.zeros((100, 100), chunks=(10, 10), dtype="f4")[:] | ||
|
||
elif name == ArrayLibraries.dask: | ||
import dask | ||
import dask.array as da | ||
|
||
return dask.__version__, da.linspace(0, 1, 10, dtype="float") | ||
|
||
elif name == ArrayLibraries.tensorflow: | ||
import tensorflow as tf | ||
|
||
return tf.__version__, tf.linspace(0, 1, 10) | ||
|
||
elif name == ArrayLibraries.xarray: | ||
import xarray as xr | ||
|
||
return xr.__version__, xr.DataArray([1.0, 2.0, 3.0]) | ||
|
||
|
||
if __name__ == "__main__": | ||
for library in ArrayLibraries: | ||
# import the array library | ||
vrsn, arr = get_array_from_library(library) | ||
|
||
# try adding a float | ||
try: | ||
_ = arr + 1.0 | ||
_ = 1.0 + arr | ||
except Exception: # noqa: BLE001 | ||
print(f"{library.name}-{vrsn} failed to add a float") # noqa: T201 | ||
else: | ||
print(f"{library.name}-{vrsn} added a float") # noqa: T201 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
.. _api-static_typing:.. _My target: | ||
|
||
############# | ||
Static Typing | ||
############# | ||
|
||
|
||
.. csv-table:: Example :rst:dir:`csv-table` | ||
:header: "Array Library", "Version", "Can add float" | ||
|
||
"NumPy", "1.24.1" , "Yes" | ||
"Zarr" , "2.13.6" , "Yes" | ||
"Dask" , "2023.1.1", "Yes" | ||
"Xarray", "2023.2.0", "Yes" | ||
"Jax" , "0.4.3", "Yes" | ||
"Torch", "1.13.1", "Yes" | ||
"Tensorflow", "2.11.0", "Yes" | ||
|
||
Last tested 2023-02-09. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters