diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index efec6cd2e561..4aef332eeb97 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -147,11 +147,12 @@ def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: loc = xla_client.Traceback.code_addr2location(code, lasti) start_line, start_column, end_line, end_column = loc return Frame(file_name=code.co_filename, - function_name=code.co_name, + function_name=code.co_qualname, start_line=start_line, start_column=start_column, end_line=end_line, end_column=end_column) else: def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: + # pre-3.11 co_qualname does not exist, use co_name return Frame(file_name=code.co_filename, function_name=code.co_name, start_line=xla_client.Traceback.code_addr2line(code, lasti),