diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index d7983d149b5..eedd5898639 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -389,7 +389,14 @@ def _component_np_dtype_char(t: type) -> Optional[str]: if not component_type and sys.version_info.minor > 8: import types if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): - component_type = t.__args__[1].__args__[0] if t.__args__[0] == typing.Any else t.__args__[0] + nargs = len(t.__args__) + if nargs == 1: + component_type = t.__args__[0] + elif nargs == 2: # for npt.NDArray[np.int64], etc. + a0 = t.__args__[0] + a1 = t.__args__[1] + if a0 == typing.Any and isinstance(a1, types.GenericAlias): + component_type = a1.__args__[0] if component_type: return _np_dtype_char(component_type) diff --git a/py/server/tests/test_pyfunc_return_java_values.py b/py/server/tests/test_pyfunc_return_java_values.py index 9e245c8eec8..9fdc1873342 100644 --- a/py/server/tests/test_pyfunc_return_java_values.py +++ b/py/server/tests/test_pyfunc_return_java_values.py @@ -2,11 +2,13 @@ # Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending # import datetime +import typing import unittest from typing import List, Union, Tuple, Sequence import numba as nb import numpy as np +import numpy.typing as npt import pandas as pd from deephaven import empty_table, dtypes, DHError @@ -211,14 +213,31 @@ def f() -> np.ndarray[np.int64]: self.assertIn("not support multi-dimensional arrays", str(cm.exception)) def test_npt_NDArray_return_type(self): - import numpy.typing as npt - def f() -> npt.NDArray[np.int64]: return np.array([1, 2], dtype=np.int64) t = empty_table(10).update(["X1 = f()"]) self.assertEqual(t.columns[0].data_type, dtypes.long_array) + def test_ndarray_weird_cases(self): + def f() -> np.ndarray[typing.Any]: + return np.array([1, 2], dtype=np.int64) + + t = empty_table(10).update(["X1 = f()"]) + self.assertEqual(t.columns[0].data_type, dtypes.PyObject) + + def f1() -> npt.NDArray[typing.Any]: + return np.array([1, 2], dtype=np.int64) + + t = empty_table(10).update(["X1 = f1()"]) + self.assertEqual(t.columns[0].data_type, dtypes.PyObject) + + def f2() -> np.ndarray[typing.Any, np.int64]: + return np.array([1, 2], dtype=np.int64) + + t = empty_table(10).update(["X1 = f2()"]) + self.assertEqual(t.columns[0].data_type, dtypes.PyObject) + if __name__ == '__main__': unittest.main()