Skip to content

Commit

Permalink
Fortity numpy array typehint parsing (deephaven#4671)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver authored Oct 20, 2023
1 parent 044d059 commit 8e4d243
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
9 changes: 8 additions & 1 deletion py/server/deephaven/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,14 @@ def _component_np_dtype_char(t: type) -> Optional[str]:
if not component_type and sys.version_info.minor > 8:
import types
if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray):
component_type = t.__args__[1].__args__[0] if t.__args__[0] == typing.Any else t.__args__[0]
nargs = len(t.__args__)
if nargs == 1:
component_type = t.__args__[0]
elif nargs == 2: # for npt.NDArray[np.int64], etc.
a0 = t.__args__[0]
a1 = t.__args__[1]
if a0 == typing.Any and isinstance(a1, types.GenericAlias):
component_type = a1.__args__[0]

if component_type:
return _np_dtype_char(component_type)
Expand Down
23 changes: 21 additions & 2 deletions py/server/tests/test_pyfunc_return_java_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending
#
import datetime
import typing
import unittest
from typing import List, Union, Tuple, Sequence

import numba as nb
import numpy as np
import numpy.typing as npt
import pandas as pd

from deephaven import empty_table, dtypes, DHError
Expand Down Expand Up @@ -211,14 +213,31 @@ def f() -> np.ndarray[np.int64]:
self.assertIn("not support multi-dimensional arrays", str(cm.exception))

def test_npt_NDArray_return_type(self):
import numpy.typing as npt

def f() -> npt.NDArray[np.int64]:
return np.array([1, 2], dtype=np.int64)

t = empty_table(10).update(["X1 = f()"])
self.assertEqual(t.columns[0].data_type, dtypes.long_array)

def test_ndarray_weird_cases(self):
def f() -> np.ndarray[typing.Any]:
return np.array([1, 2], dtype=np.int64)

t = empty_table(10).update(["X1 = f()"])
self.assertEqual(t.columns[0].data_type, dtypes.PyObject)

def f1() -> npt.NDArray[typing.Any]:
return np.array([1, 2], dtype=np.int64)

t = empty_table(10).update(["X1 = f1()"])
self.assertEqual(t.columns[0].data_type, dtypes.PyObject)

def f2() -> np.ndarray[typing.Any, np.int64]:
return np.array([1, 2], dtype=np.int64)

t = empty_table(10).update(["X1 = f2()"])
self.assertEqual(t.columns[0].data_type, dtypes.PyObject)


if __name__ == '__main__':
unittest.main()

0 comments on commit 8e4d243

Please sign in to comment.