diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index cab8bbb3..885b1396 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -535,8 +535,10 @@ def add_imports(type_hint): type_hint, "__origin__" ): # This checks for higher-order types like List, Dict module_name = type_hint.__module__ - type_name = type_hint._name - for arg in type_hint.__args__: + type_name = ( + getattr(type_hint, "_name", None) or type_hint.__origin__.__name__ + ) + for arg in getattr(type_hint, "__args__", []): if arg is type(None): # noqa: E721 continue add_imports(arg) # Recursively add imports for each argument diff --git a/tests/test_python.py b/tests/test_python.py index 805c56c4..579a30ab 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -66,41 +66,30 @@ def add(x, **kwargs): assert wg.tasks["add"].outputs["result"].value.value == 6 -def test_PythonJob_typing(fixture_localhost): +def test_PythonJob_typing(): """Test function with typing.""" from numpy import array - - def add(x: array, y: array) -> array: - return x + y - - def multiply(x: Any, y: Any) -> Any: - return x * y - - wg = WorkGraph("test_PythonJob") - wg.add_task("PythonJob", function=add, name="add") - wg.add_task( - "PythonJob", function=multiply, name="multiply", x=wg.tasks["add"].outputs[0] - ) - # - metadata = { - "options": { - "custom_scheduler_commands": "# test", - # "custom_scheduler_commands": 'module load anaconda\nconda activate py3.11\n', - } + from ase import Atoms + from aiida_workgraph.utils import get_required_imports + from typing import List + + def generate_structures( + structures: List[Atoms], + strain_lst: list, + data: array, + strain_lst1: list = None, + data1: array = None, + structure1: Atoms = None, + ) -> list[Atoms]: + pass + + modules = get_required_imports(generate_structures) + assert modules == { + "ase.atoms": {"Atoms"}, + "typing": {"List"}, + "builtins": {"list"}, + "numpy": {"array"}, } - wg.run( - inputs={ - "add": { - "x": array([1, 2]), - "y": array([2, 3]), - "computer": "localhost", - "metadata": metadata, - }, - "multiply": {"y": 4, "computer": "localhost", "metadata": metadata}, - }, - # wait=True, - ) - assert (wg.tasks["multiply"].outputs["result"].value.value == array([12, 20])).all() def test_PythonJob_outputs(fixture_localhost):