diff --git a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py index c61f1e7fe8..0f2b8c63cc 100644 --- a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py +++ b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py @@ -120,16 +120,27 @@ def create_struct(type_map: Dict[str, str], value_map: Dict[str, Any], base_conf def parse_type_description(type_desc: str) -> Type: """Parse the type description and return the corresponding type.""" - if re.match(r".+\[.*]", type_desc): - origin_module, origin_type = type_desc.split("[")[0].rsplit(".", 1) - origin = importlib.import_module(origin_module).__getattribute__(origin_type) - sub_types = type_desc.split("[")[1][:-1].split(", ") - for i, t in enumerate(sub_types): - if t != "NoneType": - module_name, class_name = t.rsplit(".", 1) - sub_types[i] = importlib.import_module(module_name).__getattribute__(class_name) + generic_pattern = re.compile(r"(?P[^\[\]]+)\[(?P[^\[\]]+)\]") + match = generic_pattern.match(type_desc) + + if match: + origin_type = match.group("type") + args = match.group("args").split(", ") + + origin_module, origin_class = origin_type.rsplit(".", 1) + origin = importlib.import_module(origin_module).__getattribute__(origin_class) + + sub_types = [] + for arg in args: + if arg == "NoneType": + sub_types.append(type(None)) else: - sub_types[i] = type(None) + module_name, class_name = arg.rsplit(".", 1) + sub_type = importlib.import_module(module_name).__getattribute__(class_name) + sub_types.append(sub_type) + + if origin_class == "Optional": + return origin[sub_types[0]] return origin[tuple(sub_types)] else: module_name, class_name = type_desc.rsplit(".", 1)