From 68ac334a2d2e554a49a07d09db215f1b55509a85 Mon Sep 17 00:00:00 2001
From: Christopher Bignamini
Date: Tue, 11 Jul 2023 12:45:32 +0200
Subject: [PATCH 1/4] WIP: Fix multiline declaration handling

---
 tools/src/icon4pytools/f2ser/parse.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/tools/src/icon4pytools/f2ser/parse.py b/tools/src/icon4pytools/f2ser/parse.py
index 28fdacb725..26275286b7 100644
--- a/tools/src/icon4pytools/f2ser/parse.py
+++ b/tools/src/icon4pytools/f2ser/parse.py
@@ -21,7 +21,6 @@ from icon4pytools.f2ser.exceptions import MissingDerivedTypeError, ParsingError
 
-
 def crack(path: Path) -> dict:
     return crackfortran(path)[0]
 
@@ -293,12 +292,25 @@ def get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
         end_subroutine_ln = code[: end_match.start()].count("\n") + 1
 
         # Find the last intent statement line number in the subroutine
-        declaration_pattern = r".*::\s*(\w+\b)"
-        declaration_pattern_lines = [
-            i
-            for i, line in enumerate(code.splitlines()[start_subroutine_ln:end_subroutine_ln])
-            if re.search(declaration_pattern, line)
-        ]
+        declaration_pattern = r".*::\s*(\w+\b)|.*::.*(\&)"
+        is_multiline_declaration = False
+        declaration_pattern_lines = []
+        for i, line in enumerate(code.splitlines()[start_subroutine_ln:end_subroutine_ln]):
+            if is_multiline_declaration == False:
+                if re.search(declaration_pattern, line):
+                    # this is a declaration line, don't know if single or multiline
+                    declaration_pattern_lines.append(i)
+                    if line.find("&") != -1:
+                        # this is a multiline declaration block
+                        is_multiline_declaration = True
+            else:
+                if is_multiline_declaration == True:
+                    # this is the continuation of a multiline declaration block
+                    declaration_pattern_lines.append(i)
+                    if line.find("&") == -1:
+                        # this is the last line of a multiline declaration block
+                        is_multiline_declaration = False
+
         if not declaration_pattern_lines:
             raise ParsingError(f"No declarations found in {self.granule_path}")
         last_declaration_ln = declaration_pattern_lines[-1] + start_subroutine_ln + 1
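Note on the detection loop added above: a line opens a declaration when it matches the pattern r".*::\s*(\w+\b)|.*::.*(\&)", and a trailing "&" (the Fortran line-continuation marker) keeps the declaration block open until a line without "&" closes it. The following self-contained sketch shows the same logic outside the parser; it is illustrative only and not part of the patch, and the helper name and the sample subroutine are invented for the example:

    import re

    def find_declaration_lines(lines):
        # Same pattern as the patch: a "::" declaration, possibly continued with "&".
        declaration_pattern = r".*::\s*(\w+\b)|.*::.*(\&)"
        is_multiline_declaration = False
        declaration_pattern_lines = []
        for i, line in enumerate(lines):
            if not is_multiline_declaration:
                if re.search(declaration_pattern, line):
                    declaration_pattern_lines.append(i)
                    if "&" in line:
                        # a trailing "&" opens a multiline declaration block
                        is_multiline_declaration = True
            else:
                # continuation line of an open multiline declaration block
                declaration_pattern_lines.append(i)
                if "&" not in line:
                    # no "&" means this is the last line of the block
                    is_multiline_declaration = False
        return declaration_pattern_lines

    sample = [
        "SUBROUTINE foo(a, b, c)",
        "  REAL(wp), INTENT(IN) :: a, &",
        "     b, &",
        "     c",
        "  INTEGER :: n",
        "END SUBROUTINE foo",
    ]
    print(find_declaration_lines(sample))  # -> [1, 2, 3, 4]

The printed indices cover the three lines of the continued REAL declaration plus the single-line INTEGER declaration, which is exactly the set of lines the parser needs in order to place generated code after the last declaration.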
From 494d93dc7c1e531732a861a202c59e13553e6579 Mon Sep 17 00:00:00 2001
From: samkellerhals
Date: Tue, 11 Jul 2023 14:39:09 +0200
Subject: [PATCH 2/4] Add docstrings and small cleanup

---
 tools/src/icon4pytools/f2ser/parse.py         | 134 ++++--
 .../fortran_samples/multiline_example.f90     | 423 ++++++++++++++++++
 tools/tests/f2ser/test_parsing.py             |  13 +
 3 files changed, 530 insertions(+), 40 deletions(-)
 create mode 100644 tools/tests/f2ser/fortran_samples/multiline_example.f90

diff --git a/tools/src/icon4pytools/f2ser/parse.py b/tools/src/icon4pytools/f2ser/parse.py
index 26275286b7..a8c4b36a3e 100644
--- a/tools/src/icon4pytools/f2ser/parse.py
+++ b/tools/src/icon4pytools/f2ser/parse.py
@@ -21,6 +21,7 @@ from icon4pytools.f2ser.exceptions import MissingDerivedTypeError, ParsingError
 
+
 def crack(path: Path) -> dict:
     return crackfortran(path)[0]
 
@@ -65,9 +66,38 @@ def __init__(self, granule: Path, dependencies: Optional[list[Path]] = None) ->
     def __call__(self) -> ParsedGranule:
         """Parse the granule and return the parsed data."""
         subroutines = self.parse_subroutines()
-        last_import_ln = self.find_last_fortran_use_statement()
+        last_import_ln = self._find_last_fortran_use_statement()
         return ParsedGranule(subroutines=subroutines, last_import_ln=last_import_ln)
 
+    def _find_last_fortran_use_statement(self) -> Optional[int]:
+        """Finds the line number of the last Fortran USE statement in the code.
+
+        Returns:
+            Optional[int]: the line number of the last USE statement, or None if no USE statement is found.
+        """
+        # Reverse the order of the lines so we can search from the end
+        code = self._read_code_from_file()
+        code_lines = code.splitlines()
+        code_lines.reverse()
+
+        # Look for the last USE statement
+        use_ln = None
+        for i, line in enumerate(code_lines):
+            if line.strip().lower().startswith("use"):
+                use_ln = len(code_lines) - i
+                if i > 0 and code_lines[i - 1].strip().lower() == "#endif":
+                    # If the USE statement is preceded by an #endif statement, return the line number after the #endif statement
+                    return use_ln + 1
+                else:
+                    return use_ln
+        return None
+
+    def _read_code_from_file(self) -> str:
+        """Reads the content of the granule and returns it as a string."""
+        with open(self.granule_path) as f:
+            code = f.read()
+        return code
+
     def parse_subroutines(self):
         subroutines = self._extract_subroutines(crack(self.granule_path))
         variables_grouped_by_intent = {
@@ -245,31 +275,12 @@ def _update_with_codegen_lines(self, parsed_types: dict) -> dict:
         with_lines = deepcopy(parsed_types)
         for subroutine in with_lines:
             for intent in with_lines[subroutine]:
-                with_lines[subroutine][intent]["codegen_ctx"] = self.get_subroutine_lines(
+                with_lines[subroutine][intent]["codegen_ctx"] = self._get_subroutine_lines(
                     subroutine
                 )
         return with_lines
 
-    def find_last_fortran_use_statement(self):
-        with open(self.granule_path) as f:
-            file_contents = f.readlines()
-
-        # Reverse the order of the lines so we can search from the end
-        file_contents.reverse()
-
-        # Look for the last USE statement
-        use_ln = None
-        for i, line in enumerate(file_contents):
-            if line.strip().lower().startswith("use"):
-                use_ln = len(file_contents) - i
-                if i > 0 and file_contents[i - 1].strip().lower() == "#endif":
-                    # If the USE statement is preceded by an #endif statement, return the line number after the #endif statement
-                    return use_ln + 1
-                else:
-                    return use_ln
-        return None
-
-    def get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
+    def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
         """Return CodegenContext object containing line numbers of the last declaration statement and the code before the end of the given subroutine.
 
         Args:
@@ -278,10 +289,33 @@ def get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
         Returns:
             CodegenContext: Object containing the line number of the last declaration statement and the line number of the last line of the code before the end of the given subroutine.
         """
-        with open(self.granule_path) as f:
-            code = f.read()
+        code = self._read_code_from_file()
+
+        start_subroutine_ln, end_subroutine_ln = self._find_subroutine_lines(code, subroutine_name)
+
+        variable_declaration_ln = self._find_variable_declarations(code, start_subroutine_ln, end_subroutine_ln)
+
+        if not variable_declaration_ln:
+            raise ParsingError(f"No variable declarations found in {self.granule_path}")
+
+        first_declaration_ln, last_declaration_ln = self._get_variable_declaration_bounds(variable_declaration_ln,
+                                                                                          start_subroutine_ln)
+
+        pre_end_subroutine_ln = end_subroutine_ln - 1  # we want to generate the code before the end of the subroutine
+
+        return CodegenContext(first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln)
 
-        # Find the line number where the subroutine is defined
+    @staticmethod
+    def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int, int]:
+        """Finds line numbers of a subroutine within a code block.
+ + Args: + code (str): The code block to search for the subroutine. + subroutine_name (str): Name of the subroutine to find. + + Returns: + tuple: Line numbers of the start and end of the subroutine. + """ start_subroutine_pattern = r"SUBROUTINE\s+" + subroutine_name + r"\s*\(" end_subroutine_pattern = r"END\s+SUBROUTINE\s+" + subroutine_name + r"\s*" start_match = re.search(start_subroutine_pattern, code) @@ -290,34 +324,54 @@ def get_subroutine_lines(self, subroutine_name: str) -> CodegenContext: return None start_subroutine_ln = code[: start_match.start()].count("\n") + 1 end_subroutine_ln = code[: end_match.start()].count("\n") + 1 + return start_subroutine_ln, end_subroutine_ln + + @staticmethod + def _find_variable_declarations(code: str, start_subroutine_ln: int, end_subroutine_ln: int) -> list: + """Finds line numbers of variable declarations within a code block. - # Find the last intent statement line number in the subroutine + Args: + code (str): The code block to search for variable declarations. + start_subroutine_ln (int): Starting line number of the subroutine. + end_subroutine_ln (int): Ending line number of the subroutine. + + Returns: + list: Line numbers of variable declaration lines. + + This method identifies single-line and multiline variable declarations within + the specified code block, delimited by the start and end line numbers of the + subroutine. Multiline declarations are detected by the presence of an ampersand + character ('&') at the end of a line. + """ declaration_pattern = r".*::\s*(\w+\b)|.*::.*(\&)" is_multiline_declaration = False declaration_pattern_lines = [] + for i, line in enumerate(code.splitlines()[start_subroutine_ln:end_subroutine_ln]): - if is_multiline_declaration == False: + if not is_multiline_declaration: if re.search(declaration_pattern, line): - # this is a declaration line, don't know if single or multiline declaration_pattern_lines.append(i) if line.find("&") != -1: - # this is a multiline declaration block is_multiline_declaration = True else: - if is_multiline_declaration == True: - # this is the continuation of a multiline declaration block + if is_multiline_declaration: declaration_pattern_lines.append(i) if line.find("&") == -1: - # this is the last line of a multiline declaration block is_multiline_declaration = False - if not declaration_pattern_lines: - raise ParsingError(f"No declarations found in {self.granule_path}") - last_declaration_ln = declaration_pattern_lines[-1] + start_subroutine_ln + 1 - first_declaration_ln = declaration_pattern_lines[0] + start_subroutine_ln + return declaration_pattern_lines - pre_end_subroutine_ln = ( - end_subroutine_ln - 1 - ) # we want to generate the code before the end of the subroutine + @staticmethod + def _get_variable_declaration_bounds(declaration_pattern_lines: list, start_subroutine_ln: int) -> tuple: + """Returns the line numbers of the bounds for a variable declaration block. - return CodegenContext(first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln) + Args: + declaration_pattern_lines (list): List of line numbers representing the relative positions of lines within the declaration block. + start_subroutine_ln (int): Line number indicating the starting line of the subroutine. + + Returns: + tuple: Line number of the first declaration line, line number following the last declaration line. 
+ """ + first_declaration_ln = declaration_pattern_lines[0] + start_subroutine_ln + last_declaration_ln = declaration_pattern_lines[-1] + start_subroutine_ln + 1 + return first_declaration_ln, last_declaration_ln diff --git a/tools/tests/f2ser/fortran_samples/multiline_example.f90 b/tools/tests/f2ser/fortran_samples/multiline_example.f90 new file mode 100644 index 0000000000..c3d403310f --- /dev/null +++ b/tools/tests/f2ser/fortran_samples/multiline_example.f90 @@ -0,0 +1,423 @@ +MODULE mo_graupel_granule + + USE, INTRINSIC :: iso_fortran_env, ONLY: wp => real64 + USE gscp_graupel, ONLY: graupel + + IMPLICIT NONE + + PRIVATE :: graupel_parameters, params, mma,mmb + PUBLIC :: graupel_init, graupel_run + + TYPE graupel_parameters + ! gscp_data + INTEGER :: iautocon + INTEGER :: isnow_n0temp + REAL(wp) :: ccsrim + REAL(wp) :: ccsagg + REAL(wp) :: ccsdep + REAL(wp) :: ccsvel + REAL(wp) :: ccsvxp + REAL(wp) :: ccslam + REAL(wp) :: ccslxp + REAL(wp) :: ccsaxp + REAL(wp) :: ccsdxp + REAL(wp) :: ccshi1 + REAL(wp) :: ccdvtp + REAL(wp) :: ccidep + REAL(wp) :: ccswxp + REAL(wp) :: zconst + REAL(wp) :: zcev + REAL(wp) :: zbev + REAL(wp) :: zcevxp + REAL(wp) :: zbevxp + REAL(wp) :: zvzxp + REAL(wp) :: zvz0r + REAL(wp) :: v0snow + REAL(wp) :: x13o8 + REAL(wp) :: x1o2 + REAL(wp) :: x27o16 + REAL(wp) :: x3o4 + REAL(wp) :: x7o4 + REAL(wp) :: x7o8 + REAL(wp) :: zbvi + REAL(wp) :: zcac + REAL(wp) :: zccau + REAL(wp) :: zciau + REAL(wp) :: zcicri + REAL(wp) :: zcrcri + REAL(wp) :: zcrfrz + REAL(wp) :: zcrfrz1 + REAL(wp) :: zcrfrz2 + REAL(wp) :: zeps + REAL(wp) :: zkcac + REAL(wp) :: zkphi1 + REAL(wp) :: zkphi2 + REAL(wp) :: zkphi3 + REAL(wp) :: zmi0 + REAL(wp) :: zmimax + REAL(wp) :: zmsmin + REAL(wp) :: zn0s0 + REAL(wp) :: zn0s1 + REAL(wp) :: zn0s2 + REAL(wp) :: znimax_thom + REAL(wp) :: zqmin + REAL(wp) :: zrho0 + REAL(wp) :: zthet + REAL(wp) :: zthn + REAL(wp) :: ztmix + REAL(wp) :: ztrfrz + REAL(wp) :: zvz0i + REAL(wp) :: icesedi_exp + REAL(wp) :: zams + REAL(wp) :: dist_cldtop_ref + REAL(wp) :: reduce_dep_ref + REAL(wp) :: tmin_iceautoconv + REAL(wp) :: zceff_fac + REAL(wp) :: zceff_min + REAL(wp) :: v_sedi_rain_min + REAL(wp) :: v_sedi_snow_min + REAL(wp) :: v_sedi_graupel_min + + ! mo_physical constants + REAL(wp) :: r_v + REAL(wp) :: lh_v + REAL(wp) :: lh_s + REAL(wp) :: cpdr + REAL(wp) :: cvdr + REAL(wp) :: b3 + REAL(wp) :: t0 + END TYPE graupel_parameters + + + + REAL(wp):: mma(10),mmb(10) + TYPE(graupel_parameters) :: params + +CONTAINS + + SUBROUTINE graupel_init( & + ccsrim, ccsagg, ccsdep, ccsvel, ccsvxp, ccslam, & + ccslxp, ccsaxp, ccsdxp, ccshi1, ccdvtp, ccidep, & + ccswxp, zconst, zcev, zbev, zcevxp, zbevxp, & + zvzxp, zvz0r, & + v0snow, & + x13o8, x1o2, x27o16, x3o4, x7o4, x7o8, & + zbvi, zcac, zccau, zciau, zcicri, & + zcrcri, zcrfrz, zcrfrz1, zcrfrz2, zeps, zkcac, & + zkphi1, zkphi2, zkphi3, zmi0, zmimax, zmsmin, & + zn0s0, zn0s1, zn0s2, znimax_thom, zqmin, & + zrho0, zthet, zthn, ztmix, ztrfrz, & + zvz0i, icesedi_exp, zams, & + iautocon, isnow_n0temp, dist_cldtop_ref, reduce_dep_ref, & + tmin_iceautoconv, zceff_fac, zceff_min, & + mma_driver, mmb_driver, v_sedi_rain_min, v_sedi_snow_min, v_sedi_graupel_min, & + r_v , & !> gas constant for water vapour + lh_v, & !! latent heat of vapourization + lh_s, & !! latent heat of sublimation + cpdr, & !! (spec. heat of dry air at constant press)^-1 + cvdr, & !! (spec. heat of dry air at const vol)^-1 + b3, & !! melting temperature of ice/snow + t0) !! 
melting temperature of ice/snow + + INTEGER , INTENT(IN) :: iautocon,isnow_n0temp + REAL(wp), INTENT(IN) :: ccsrim, & + ccsagg, ccsdep, ccsvel, ccsvxp, ccslam, & + ccslxp, ccsaxp, ccsdxp, ccshi1, ccdvtp, ccidep, & + ccswxp, zconst, zcev, zbev, zcevxp, zbevxp, & + zvzxp, zvz0r, & + v0snow, & + x13o8, x1o2, x27o16, x3o4, x7o4, x7o8, & + zbvi, zcac, zccau, zciau, zcicri, & + zcrcri, zcrfrz, zcrfrz1, zcrfrz2, zeps, zkcac, & + zkphi1, zkphi2, zkphi3, zmi0, zmimax, zmsmin, & + zn0s0, zn0s1, zn0s2, znimax_thom, zqmin, & + zrho0, zthet, zthn, ztmix, ztrfrz, & + zvz0i, icesedi_exp, zams, & + dist_cldtop_ref, reduce_dep_ref, & + tmin_iceautoconv, zceff_fac, zceff_min, & + mma_driver(10), mmb_driver(10), v_sedi_rain_min, v_sedi_snow_min, v_sedi_graupel_min, & + r_v , & !> gas constant for water vapour + lh_v, & !! latent heat of vapourization + lh_s, & !! latent heat of sublimation + cpdr, & !! (spec. heat of dry air at constant press)^-1 + cvdr, & !! (spec. heat of dry air at const vol)^-1 + b3, & !! melting temperature of ice/snow + t0 !! melting temperature of ice/snow + + ! gscp_data + params%ccsrim = ccsrim + params%ccsagg = ccsagg + params%ccsdep = ccsdep + params%ccsvel = ccsvel + params%ccsvxp = ccsvxp + params%ccslam = ccslam + params%ccslxp = ccslxp + params%ccsaxp = ccsaxp + params%ccsdxp = ccsdxp + params%ccshi1 = ccshi1 + params%ccdvtp = ccdvtp + params%ccidep = ccidep + params%ccswxp = ccswxp + params%zconst = zconst + params%zcev = zcev + params%zbev = zbev + params%zcevxp = zcevxp + params%zbevxp = zbevxp + params%zvzxp = zvzxp + params%zvz0r = zvz0r + params%v0snow = v0snow + params%x13o8 = x13o8 + params%x1o2 = x1o2 + params%x27o16 = x27o16 + params%x3o4 = x3o4 + params%x7o4 = x7o4 + params%x7o8 = x7o8 + params%zbvi = zbvi + params%zcac = zcac + params%zccau = zccau + params%zciau = zciau + params%zcicri = zcicri + params%zcrcri = zcrcri + params%zcrfrz = zcrfrz + params%zcrfrz1 = zcrfrz1 + params%zcrfrz2 = zcrfrz2 + params%zeps = zeps + params%zkcac = zkcac + params%zkphi1 = zkphi1 + params%zkphi2 = zkphi2 + params%zkphi3 = zkphi3 + params%zmi0 = zmi0 + params%zmimax = zmimax + params%zmsmin = zmsmin + params%zn0s0 = zn0s0 + params%zn0s1 = zn0s1 + params%zn0s2 = zn0s2 + params%znimax_thom = znimax_thom + params%zqmin = zqmin + params%zrho0 = zrho0 + params%zthet = zthet + params%zthn = zthn + params%ztmix = ztmix + params%ztrfrz = ztrfrz + params%zvz0i = zvz0i + params%icesedi_exp = icesedi_exp + params%zams = zams + params%iautocon = iautocon + params%isnow_n0temp = isnow_n0temp + params%dist_cldtop_ref = dist_cldtop_ref + params%reduce_dep_ref = reduce_dep_ref + params%tmin_iceautoconv = tmin_iceautoconv + params%zceff_fac = zceff_fac + params%zceff_min = zceff_min + params%v_sedi_rain_min = v_sedi_rain_min + params%v_sedi_snow_min = v_sedi_snow_min + params%v_sedi_graupel_min = v_sedi_graupel_min + + ! mo_physical constants + params%r_v = r_v + params%lh_v = lh_v + params%lh_s = lh_s + params%cpdr = cpdr + params%cvdr = cvdr + params%b3 = b3 + params%t0 = t0 + + + + mma = mma_driver + mmb = mmb_driver + + !$ACC ENTER DATA COPYIN(mma, mmb) + + END SUBROUTINE graupel_init + + + SUBROUTINE graupel_run( & + nvec,ke, & !> array dimensions + ivstart,ivend, kstart, & !! optional start/end indicies + idbg, & !! optional debug level + zdt, dz, & !! numerics parameters + t,p,rho,qv,qc,qi,qr,qs,qg,qnc, & !! prognostic variables + qi0,qc0, & !! cloud ice/water threshold for autoconversion + & b1, & + & b2w, & + & b4w, & + prr_gsp,prs_gsp,pri_gsp,prg_gsp, & !! 
surface precipitation rates + qrsflux, & ! total precipitation flux + l_cv, & + ithermo_water, & ! water thermodynamics + ldass_lhn, & + ldiag_ttend, ldiag_qtend , & + ddt_tend_t , ddt_tend_qv , & + ddt_tend_qc , ddt_tend_qi , & !> ddt_tend_xx are tendencies + ddt_tend_qr , ddt_tend_qs)!! necessary for dynamics + + INTEGER, INTENT(IN) :: nvec , & !> number of horizontal points + ke !! number of grid points in vertical direction + + INTEGER, INTENT(IN) :: ivstart , & !> optional start index for horizontal direction + ivend , & !! optional end index for horizontal direction + kstart , & !! optional start index for the vertical index + idbg !! optional debug level + + REAL(KIND=wp), INTENT(IN) :: zdt , & !> time step for integration of microphysics ( s ) + qi0,qc0,& !> cloud ice/water threshold for autoconversion + b1,b2w,b4w + + REAL(KIND=wp), DIMENSION(:,:), INTENT(IN) :: dz , & !> layer thickness of full levels ( m ) + rho , & !! density of moist air (kg/m3) + p !! pressure ( Pa ) + + LOGICAL, INTENT(IN):: l_cv, & !! if true, cv is used instead of cp + ldass_lhn + + INTEGER, INTENT(IN):: ithermo_water !! water thermodynamics + + LOGICAL, INTENT(IN):: ldiag_ttend, & ! if true, temperature tendency shall be diagnosed + ldiag_qtend ! if true, moisture tendencies shall be diagnosed + + REAL(KIND=wp), DIMENSION(:,:), INTENT(INOUT) :: t , & !> temperature ( K ) + qv , & !! specific water vapor content (kg/kg) + qc , & !! specific cloud water content (kg/kg) + qi , & !! specific cloud ice content (kg/kg) + qr , & !! specific rain content (kg/kg) + qs , & !! specific snow content (kg/kg) + qg !! specific graupel content (kg/kg) + + REAL(KIND=wp), INTENT(INOUT) :: qrsflux(:,:) ! total precipitation flux (nudg) + + REAL(KIND=wp), DIMENSION(:), INTENT(INOUT) :: prr_gsp, & !> precipitation rate of rain, grid-scale (kg/(m2*s)) + prs_gsp, & !! precipitation rate of snow, grid-scale (kg/(m2*s)) + prg_gsp, & !! precipitation rate of graupel, grid-scale (kg/(m2*s)) + qnc !! cloud number concentration + + REAL(KIND=wp), DIMENSION(:), INTENT(INOUT):: pri_gsp !! precipitation rate of ice, grid-scale (kg/(m2*s)) + + REAL(KIND=wp), DIMENSION(:,:), INTENT(OUT):: ddt_tend_t , & !> tendency T ( 1/s ) + ddt_tend_qv , & !! tendency qv ( 1/s ) + ddt_tend_qc , & !! tendency qc ( 1/s ) + ddt_tend_qi , & !! tendency qi ( 1/s ) + ddt_tend_qr , & !! tendency qr ( 1/s ) + ddt_tend_qs !! tendency qs ( 1/s ) + + CALL graupel ( & + & nvec =nvec , & !> in: actual array size + & ke =ke , & !< in: actual array size + & ivstart=ivstart , & !< in: start index of calculation + & ivend =ivend , & !< in: end index of calculation + & kstart =kstart , & !< in: vertical start index + & zdt =zdt , & !< in: timestep + & qi0 =qi0 , & + & qc0 =qc0 , & + & b1 = b1, & + & b2w = b2w, & + & b4w = b4w, & + & dz =dz , & !< in: vertical layer thickness + & t =t , & !< in: temp,tracer,... + & p =p , & !< in: full level pres + & rho =rho , & !< in: density + & qv =qv , & !< in: spec. humidity + & qc =qc , & !< in: cloud water + & qi =qi , & !< in: cloud ice + & qr =qr , & !< in: rain water + & qs =qs , & !< in: snow + & qg =qg , & !< in: graupel + & qnc = qnc , & !< cloud number concentration + & prr_gsp=prr_gsp , & !< out: precipitation rate of rain + & prs_gsp=prs_gsp , & !< out: precipitation rate of snow + & pri_gsp=pri_gsp , & !< out: precipitation rate of cloud ice + & prg_gsp=prg_gsp , & !< out: precipitation rate of graupel + & qrsflux= qrsflux , & !< out: precipitation flux + & ldiag_ttend = ldiag_ttend , & !< in: if temp. 
tendency shall be diagnosed + & ldiag_qtend = ldiag_qtend , & !< in: if moisture tendencies shall be diagnosed + & ddt_tend_t = ddt_tend_t , & !< out: tendency temperature + & ddt_tend_qv = ddt_tend_qv , & !< out: tendency QV + & ddt_tend_qc = ddt_tend_qc , & !< out: tendency QC + & ddt_tend_qi = ddt_tend_qi , & !< out: tendency QI + & ddt_tend_qr = ddt_tend_qr , & !< out: tendency QR + & ddt_tend_qs = ddt_tend_qs , & !< out: tendency QS + & idbg=idbg , & + & l_cv=l_cv , & + & ldass_lhn = ldass_lhn , & + & ithermo_water=ithermo_water, &!< in: latent heat choice + & ccsrim = params%ccsrim, & + & ccsagg = params%ccsagg, & + & ccsdep = params%ccsdep, & + & ccsvel = params%ccsvel, & + & ccsvxp = params%ccsvxp, & + & ccslam = params%ccslam, & + & ccslxp = params%ccslxp, & + & ccsaxp = params%ccsaxp, & + & ccsdxp = params%ccsdxp, & + & ccshi1 = params%ccshi1, & + & ccdvtp = params%ccdvtp, & + & ccidep = params%ccidep, & + & ccswxp = params%ccswxp, & + & zconst = params%zconst, & + & zcev = params%zcev, & + & zbev = params%zbev, & + & zcevxp = params%zcevxp, & + & zbevxp = params%zbevxp, & + & zvzxp = params%zvzxp, & + & zvz0r = params%zvz0r, & + & v0snow = params%v0snow, & + & x13o8 = params%x13o8, & + & x1o2 = params%x1o2, & + & x27o16 = params%x27o16, & + & x3o4 = params%x3o4, & + & x7o4 = params%x7o4, & + & x7o8 = params%x7o8, & + & zbvi = params%zbvi, & + & zcac = params%zcac, & + & zccau = params%zccau, & + & zciau = params%zciau, & + & zcicri = params%zcicri, & + & zcrcri = params%zcrcri, & + & zcrfrz = params%zcrfrz, & + & zcrfrz1 = params%zcrfrz1, & + & zcrfrz2 = params%zcrfrz2, & + & zeps = params%zeps, & + & zkcac = params%zkcac, & + & zkphi1 = params%zkphi1, & + & zkphi2 = params%zkphi2, & + & zkphi3 = params%zkphi3, & + & zmi0 = params%zmi0, & + & zmimax = params%zmimax, & + & zmsmin = params%zmsmin, & + & zn0s0 = params%zn0s0, & + & zn0s1 = params%zn0s1, & + & zn0s2 = params%zn0s2, & + & znimax_thom = params%znimax_thom, & + & zqmin = params%zqmin, & + & zrho0 = params%zrho0, & + & zthet = params%zthet, & + & zthn = params%zthn, & + & ztmix = params%ztmix, & + & ztrfrz = params%ztrfrz, & + & zvz0i = params%zvz0i, & + & icesedi_exp = params%icesedi_exp, & + & zams = params%zams, & + & iautocon = params%iautocon, & + & isnow_n0temp = params%isnow_n0temp, & + & dist_cldtop_ref = params%dist_cldtop_ref, & + & reduce_dep_ref = params%reduce_dep_ref, & + & tmin_iceautoconv = params%tmin_iceautoconv, & + & zceff_fac = params%zceff_fac, & + & zceff_min = params%zceff_min, & + & mma = mma, & + & mmb = mmb, & + & v_sedi_rain_min = params%v_sedi_rain_min, & + & v_sedi_snow_min = params%v_sedi_snow_min, & + & v_sedi_graupel_min = params%v_sedi_graupel_min, & + & r_v = params%r_v, & + & lh_v = params%lh_v, & + & lh_s = params%lh_s, & + & cpdr = params%cpdr, & + & cvdr = params%cvdr, & + & b3 = params%b3, & + t0 = params%t0) + + END SUBROUTINE graupel_run + +END MODULE mo_graupel_granule + diff --git a/tools/tests/f2ser/test_parsing.py b/tools/tests/f2ser/test_parsing.py index 177fc496ee..b0346e72f1 100644 --- a/tools/tests/f2ser/test_parsing.py +++ b/tools/tests/f2ser/test_parsing.py @@ -59,3 +59,16 @@ def test_granule_parsing_no_intent(samples_path): parser = GranuleParser(samples_path / "subroutine_example.f90", []) with pytest.raises(ParsingError): parser() + + +def test_multiline_declaration_parsing(samples_path): + parser = GranuleParser(samples_path / "multiline_example.f90", []) + parsed_granule = parser() + subroutines = parsed_granule.subroutines + assert list(subroutines) == 
["graupel_init", "graupel_run"] + assert subroutines["graupel_init"]["in"]["codegen_ctx"] == CodegenContext( + first_declaration_ln=121, last_declaration_ln=145, end_subroutine_ln=231 + ) + assert subroutines["graupel_run"]["in"]["codegen_ctx"] == CodegenContext( + first_declaration_ln=254, last_declaration_ln=301, end_subroutine_ln=419 + ) From 1bd6755f48e74cf4cd0a31715184583ebf0840da Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Tue, 11 Jul 2023 14:49:06 +0200 Subject: [PATCH 3/4] Reformat --- .../atm_dyn_iconam/apply_diffusion_to_vn.py | 1 - ...ute_horizontal_gradients_for_turbulance.py | 1 - ...fusion_nabla_of_theta_over_steep_points.py | 1 - tools/src/icon4pytools/common/logger.py | 4 +- tools/src/icon4pytools/f2ser/cli.py | 8 ++- tools/src/icon4pytools/f2ser/deserialise.py | 12 +++- tools/src/icon4pytools/f2ser/parse.py | 61 +++++++++++++------ .../icon4pygen/bindings/codegen/cpp.py | 16 +++-- .../icon4pygen/bindings/codegen/f90.py | 50 +++++++++++---- .../bindings/codegen/render/field.py | 10 ++- .../bindings/codegen/render/offset.py | 9 ++- .../icon4pygen/bindings/entities.py | 14 ++++- .../icon4pygen/bindings/locations.py | 4 +- .../icon4pygen/bindings/workflow.py | 4 +- tools/src/icon4pytools/icon4pygen/cli.py | 4 +- tools/src/icon4pytools/icon4pygen/metadata.py | 16 +++-- tools/src/icon4pytools/liskov/cli.py | 8 ++- .../liskov/codegen/integration/deserialise.py | 47 ++++++++++---- .../liskov/codegen/integration/template.py | 9 ++- .../codegen/serialisation/deserialise.py | 4 +- .../liskov/codegen/serialisation/generate.py | 4 +- .../liskov/codegen/serialisation/template.py | 8 ++- .../liskov/codegen/shared/deserialise.py | 4 +- .../src/icon4pytools/liskov/external/gt4py.py | 9 ++- .../src/icon4pytools/liskov/parsing/parse.py | 8 ++- .../icon4pytools/liskov/parsing/validation.py | 37 ++++++++--- .../liskov/pipeline/collection.py | 12 +++- tools/tests/f2ser/test_f2ser_codegen.py | 8 ++- .../tests/f2ser/test_granule_deserialiser.py | 6 +- tools/tests/f2ser/test_parsing.py | 4 +- tools/tests/icon4pygen/test_backend.py | 4 +- .../tests/icon4pygen/test_field_rendering.py | 8 ++- .../liskov/test_directives_deserialiser.py | 35 ++++++++--- tools/tests/liskov/test_external.py | 5 +- tools/tests/liskov/test_generation.py | 24 ++++++-- .../liskov/test_serialisation_deserialiser.py | 4 +- tools/tests/liskov/test_validation.py | 15 ++++- tools/tests/liskov/test_writer.py | 4 +- 38 files changed, 360 insertions(+), 122 deletions(-) diff --git a/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_vn.py b/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_vn.py index 64f79b4d83..b36f742d03 100644 --- a/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_vn.py +++ b/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_vn.py @@ -47,7 +47,6 @@ def _apply_diffusion_to_vn( start_2nd_nudge_line_idx_e: int32, limited_area: bool, ) -> Field[[EdgeDim, KDim], float]: - z_nabla4_e2 = _calculate_nabla4( u_vert, v_vert, diff --git a/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance.py b/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance.py index 4019b4a083..085b946137 100644 --- a/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance.py +++ b/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance.py @@ -46,7 +46,6 @@ def 
_apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance( Field[[CellDim, KDim], float], Field[[CellDim, KDim], float], ]: - vert_idx = broadcast(vert_idx, (CellDim, KDim)) dwdx, dwdy = where( diff --git a/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/truly_horizontal_diffusion_nabla_of_theta_over_steep_points.py b/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/truly_horizontal_diffusion_nabla_of_theta_over_steep_points.py index ab0574e40e..724d14b676 100644 --- a/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/truly_horizontal_diffusion_nabla_of_theta_over_steep_points.py +++ b/atm_dyn_iconam/src/icon4py/atm_dyn_iconam/truly_horizontal_diffusion_nabla_of_theta_over_steep_points.py @@ -29,7 +29,6 @@ def _truly_horizontal_diffusion_nabla_of_theta_over_steep_points( theta_v: Field[[CellDim, KDim], float], z_temp: Field[[CellDim, KDim], float], ) -> Field[[CellDim, KDim], float]: - theta_v_0 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[0]))) theta_v_1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[1]))) theta_v_2 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[2]))) diff --git a/tools/src/icon4pytools/common/logger.py b/tools/src/icon4pytools/common/logger.py index d7f7341fed..636a582d49 100644 --- a/tools/src/icon4pytools/common/logger.py +++ b/tools/src/icon4pytools/common/logger.py @@ -18,7 +18,9 @@ def setup_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: """Set up a logger with a given name and log level.""" logger = logging.getLogger(name) logger.setLevel(log_level) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) stream_handler = logging.StreamHandler() stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) diff --git a/tools/src/icon4pytools/f2ser/cli.py b/tools/src/icon4pytools/f2ser/cli.py index 290d8a54f9..a122a8d893 100644 --- a/tools/src/icon4pytools/f2ser/cli.py +++ b/tools/src/icon4pytools/f2ser/cli.py @@ -18,14 +18,18 @@ from icon4pytools.f2ser.deserialise import ParsedGranuleDeserialiser from icon4pytools.f2ser.parse import GranuleParser -from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator +from icon4pytools.liskov.codegen.serialisation.generate import ( + SerialisationCodeGenerator, +) from icon4pytools.liskov.codegen.shared.write import CodegenWriter @click.command("icon_f2ser") @click.argument( "granule_path", - type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), + type=click.Path( + exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path + ), ) @click.argument( "output_filepath", diff --git a/tools/src/icon4pytools/f2ser/deserialise.py b/tools/src/icon4pytools/f2ser/deserialise.py index 883daddc94..e0082eed66 100644 --- a/tools/src/icon4pytools/f2ser/deserialise.py +++ b/tools/src/icon4pytools/f2ser/deserialise.py @@ -55,7 +55,9 @@ def _make_savepoints(self) -> None: for intent, var_dict in intent_dict.items(): self._create_savepoint(subroutine_name, intent, var_dict) - def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -> None: + def _create_savepoint( + self, subroutine_name: str, intent: str, var_dict: dict + ) -> None: """Create a savepoint for the given variables. 
Args: @@ -71,7 +73,9 @@ def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) - FieldSerialisationData( variable=var_name, association=self._create_association(var_data, var_name), - decomposed=var_data["decomposed"] if var_data.get("decomposed") else False, + decomposed=var_data["decomposed"] + if var_data.get("decomposed") + else False, dimension=var_data.get("dimension"), typespec=var_data.get("typespec"), typename=var_data.get("typename"), @@ -135,7 +139,9 @@ def _make_init_data(self) -> None: for intent, var_dict in intent_dict.items() if intent == "in" ][0] - startln = self._get_codegen_line(first_intent_in_subroutine["codegen_ctx"], "init") + startln = self._get_codegen_line( + first_intent_in_subroutine["codegen_ctx"], "init" + ) self.data["Init"] = InitData( startln=startln, directory=self.directory, diff --git a/tools/src/icon4pytools/f2ser/parse.py b/tools/src/icon4pytools/f2ser/parse.py index a8c4b36a3e..3d8147da42 100644 --- a/tools/src/icon4pytools/f2ser/parse.py +++ b/tools/src/icon4pytools/f2ser/parse.py @@ -59,7 +59,9 @@ class GranuleParser: parsed_types = parser() """ - def __init__(self, granule: Path, dependencies: Optional[list[Path]] = None) -> None: + def __init__( + self, granule: Path, dependencies: Optional[list[Path]] = None + ) -> None: self.granule_path = granule self.dependencies = dependencies @@ -101,9 +103,12 @@ def _read_code_from_file(self) -> str: def parse_subroutines(self): subroutines = self._extract_subroutines(crack(self.granule_path)) variables_grouped_by_intent = { - name: self._extract_intent_vars(routine) for name, routine in subroutines.items() + name: self._extract_intent_vars(routine) + for name, routine in subroutines.items() } - intrinsic_type_vars, derived_type_vars = self._parse_types(variables_grouped_by_intent) + intrinsic_type_vars, derived_type_vars = self._parse_types( + variables_grouped_by_intent + ) combined_type_vars = self._combine_types(derived_type_vars, intrinsic_type_vars) with_lines = self._update_with_codegen_lines(combined_type_vars) return with_lines @@ -124,7 +129,9 @@ def _extract_subroutines(self, parsed: dict[str, Any]) -> dict[str, Any]: subroutines[name] = elt if len(subroutines) != 2: - raise ParsingError(f"Did not find _init and _run subroutines in {self.granule_path}") + raise ParsingError( + f"Did not find _init and _run subroutines in {self.granule_path}" + ) return subroutines @@ -238,7 +245,9 @@ def _decompose_derived_types(derived_types: dict) -> dict: new_type_name = f"{var_name}_{subtype_name}" new_var_dict = var_dict.copy() new_var_dict.update(subtype_spec) - decomposed_vars[subroutine][intent][new_type_name] = new_var_dict + decomposed_vars[subroutine][intent][ + new_type_name + ] = new_var_dict new_var_dict["ptr_var"] = subtype_name else: decomposed_vars[subroutine][intent][var_name] = var_dict @@ -275,9 +284,9 @@ def _update_with_codegen_lines(self, parsed_types: dict) -> dict: with_lines = deepcopy(parsed_types) for subroutine in with_lines: for intent in with_lines[subroutine]: - with_lines[subroutine][intent]["codegen_ctx"] = self._get_subroutine_lines( - subroutine - ) + with_lines[subroutine][intent][ + "codegen_ctx" + ] = self._get_subroutine_lines(subroutine) return with_lines def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext: @@ -291,19 +300,31 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext: """ code = self._read_code_from_file() - start_subroutine_ln, end_subroutine_ln = self._find_subroutine_lines(code, 
subroutine_name)
+        start_subroutine_ln, end_subroutine_ln = self._find_subroutine_lines(
+            code, subroutine_name
+        )
 
-        variable_declaration_ln = self._find_variable_declarations(code, start_subroutine_ln, end_subroutine_ln)
+        variable_declaration_ln = self._find_variable_declarations(
+            code, start_subroutine_ln, end_subroutine_ln
+        )
 
         if not variable_declaration_ln:
             raise ParsingError(f"No variable declarations found in {self.granule_path}")
 
-        first_declaration_ln, last_declaration_ln = self._get_variable_declaration_bounds(variable_declaration_ln,
-                                                                                          start_subroutine_ln)
+        (
+            first_declaration_ln,
+            last_declaration_ln,
+        ) = self._get_variable_declaration_bounds(
+            variable_declaration_ln, start_subroutine_ln
+        )
 
-        pre_end_subroutine_ln = end_subroutine_ln - 1  # we want to generate the code before the end of the subroutine
+        pre_end_subroutine_ln = (
+            end_subroutine_ln - 1
+        )  # we want to generate the code before the end of the subroutine
 
-        return CodegenContext(first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln)
+        return CodegenContext(
+            first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln
+        )
 
     @staticmethod
     def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int, int]:
@@ -327,7 +348,9 @@ def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int, int]:
         return start_subroutine_ln, end_subroutine_ln
 
     @staticmethod
-    def _find_variable_declarations(code: str, start_subroutine_ln: int, end_subroutine_ln: int) -> list:
+    def _find_variable_declarations(
+        code: str, start_subroutine_ln: int, end_subroutine_ln: int
+    ) -> list:
         """Finds line numbers of variable declarations within a code block.
 
         Args:
@@ -347,7 +370,9 @@ def _find_variable_declarations(code: str, start_subroutine_ln: int, end_subrout
         is_multiline_declaration = False
         declaration_pattern_lines = []
 
-        for i, line in enumerate(code.splitlines()[start_subroutine_ln:end_subroutine_ln]):
+        for i, line in enumerate(
+            code.splitlines()[start_subroutine_ln:end_subroutine_ln]
+        ):
             if not is_multiline_declaration:
                 if re.search(declaration_pattern, line):
                     declaration_pattern_lines.append(i)
@@ -362,7 +387,9 @@ def _find_variable_declarations(code: str, start_subroutine_ln: int, end_subrout
         return declaration_pattern_lines
 
     @staticmethod
-    def _get_variable_declaration_bounds(declaration_pattern_lines: list, start_subroutine_ln: int) -> tuple:
+    def _get_variable_declaration_bounds(
+        declaration_pattern_lines: list, start_subroutine_ln: int
+    ) -> tuple:
         """Returns the line numbers of the bounds for a variable declaration block.
Args: diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py index ba6951f572..73c7c4a234 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py @@ -27,7 +27,9 @@ run_func_declaration, run_verify_func_declaration, ) -from icon4pytools.icon4pygen.bindings.codegen.render.offset import GpuTriMeshOffsetRenderer +from icon4pytools.icon4pygen.bindings.codegen.render.offset import ( + GpuTriMeshOffsetRenderer, +) from icon4pytools.icon4pygen.bindings.entities import Field, Offset from icon4pytools.icon4pygen.bindings.utils import write_string @@ -663,8 +665,12 @@ def _get_field_data(self) -> tuple: ] sparse_fields = [field for field in self.fields if field.is_sparse()] compound_fields = [field for field in self.fields if field.is_compound()] - sparse_offsets = [offset for offset in self.offsets if not offset.is_compound_location()] - strided_offsets = [offset for offset in self.offsets if offset.is_compound_location()] + sparse_offsets = [ + offset for offset in self.offsets if not offset.is_compound_location() + ] + strided_offsets = [ + offset for offset in self.offsets if offset.is_compound_location() + ] all_fields = self.fields offsets = dict(sparse=sparse_offsets, strided=strided_offsets) @@ -709,7 +715,9 @@ def __post_init__(self) -> None: # type: ignore ), public_utilities=PublicUtilities(fields=fields["output"]), copy_pointers=CopyPointers(fields=self.fields), - private_members=PrivateMembers(fields=self.fields, out_fields=fields["output"]), + private_members=PrivateMembers( + fields=self.fields, out_fields=fields["output"] + ), setup_func=StencilClassSetupFunc( funcname=self.stencil_name, out_fields=fields["output"], diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py index a3cc0a3e41..a72aaedb67 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py @@ -164,7 +164,9 @@ class F90Generator(TemplatedGenerator): F90Field = as_jinja("{{ name }}{% if suffix %}_{{ suffix }}{% endif %}") - F90OpenACCField = as_jinja("!$ACC {{ name }}{% if suffix %}_{{ suffix }}{% endif %}") + F90OpenACCField = as_jinja( + "!$ACC {{ name }}{% if suffix %}_{{ suffix }}{% endif %}" + ) F90TypedField = as_jinja( "{{ dtype }}, {% if dims %}{{ dims }},{% endif %} target {% if _this_node.optional %} , optional {% endif %}:: {{ name }}{% if suffix %}_{{ suffix }}{% endif %} " @@ -226,10 +228,13 @@ def __post_init__(self) -> None: # type: ignore ) for field in self.all_fields ] + [ - F90TypedField(name=name, dtype="integer(c_int)", dims="value") for name in _DOMAIN_ARGS + F90TypedField(name=name, dtype="integer(c_int)", dims="value") + for name in _DOMAIN_ARGS ] - self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") + self.params = F90EntityList( + fields=param_fields, line_end=", &", line_end_last=" &" + ) self.binds = F90EntityList(fields=bind_fields) @@ -273,7 +278,9 @@ def __post_init__(self) -> None: # type: ignore ) for field in self.tol_fields: - param_fields += [F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"]] + param_fields += [ + F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"] + ] bind_fields += [ F90TypedField( name=field.name, @@ -284,7 +291,9 @@ def __post_init__(self) -> None: # type: ignore for s in ["rel_tol", "abs_tol"] ] 
- self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") + self.params = F90EntityList( + fields=param_fields, line_end=", &", line_end_last=" &" + ) self.binds = F90EntityList(fields=bind_fields) @@ -328,7 +337,9 @@ def __post_init__(self) -> None: # type: ignore for field in self.out_fields ] - self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") + self.params = F90EntityList( + fields=param_fields, line_end=", &", line_end_last=" &" + ) self.binds = F90EntityList(fields=bind_fields) @@ -390,7 +401,9 @@ def __post_init__(self) -> None: # type: ignore for field in self.tol_fields ] open_acc_fields = [ - F90OpenACCField(name=field.name) for field in self.all_fields if field.rank() != 0 + F90OpenACCField(name=field.name) + for field in self.all_fields + if field.rank() != 0 ] + [ F90OpenACCField(name=field.name, suffix="before") for field in self.out_fields @@ -420,7 +433,9 @@ def __post_init__(self) -> None: # type: ignore ] for field in self.tol_fields: - param_fields += [F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"]] + param_fields += [ + F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"] + ] bind_fields += [ F90TypedField( name=field.name, @@ -432,18 +447,25 @@ def __post_init__(self) -> None: # type: ignore for s in ["rel_tol", "abs_tol"] ] run_ver_param_fields += [ - F90Field(name=field.name, suffix=s) for s in ["rel_err_tol", "abs_err_tol"] + F90Field(name=field.name, suffix=s) + for s in ["rel_err_tol", "abs_err_tol"] ] - self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") + self.params = F90EntityList( + fields=param_fields, line_end=", &", line_end_last=" &" + ) self.binds = F90EntityList(fields=bind_fields) self.tol_decls = F90EntityList(fields=tol_fields) self.conditionals = F90EntityList(fields=cond_fields) - self.openacc = F90EntityList(fields=open_acc_fields, line_end=", &", line_end_last=" &") + self.openacc = F90EntityList( + fields=open_acc_fields, line_end=", &", line_end_last=" &" + ) self.run_ver_params = F90EntityList( fields=run_ver_param_fields, line_end=", &", line_end_last=" &" ) - self.run_params = F90EntityList(fields=run_param_fields, line_end=", &", line_end_last=" &") + self.run_params = F90EntityList( + fields=run_param_fields, line_end=", &", line_end_last=" &" + ) class F90WrapSetupFun(Node): @@ -514,7 +536,9 @@ def __post_init__(self) -> None: # type: ignore ] ] + [F90Field(name=field.name, suffix="kvert_max") for field in self.out_fields] - self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") + self.params = F90EntityList( + fields=param_fields, line_end=", &", line_end_last=" &" + ) self.binds = F90EntityList(fields=bind_fields) self.vert_decls = F90EntityList(fields=vert_fields) self.vert_conditionals = F90EntityList(fields=vert_conditionals_fields) diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py index e9106fba0e..2c170f96c7 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py @@ -34,7 +34,11 @@ def render_pointer(self) -> str: def render_dim_tags(self) -> str: """Render c++ dimension tags.""" tags = [] - if self.entity.is_dense() or self.entity.is_sparse() or self.entity.is_compound(): + if ( + self.entity.is_dense() + or self.entity.is_sparse() + or self.entity.is_compound() + ): 
tags.append("unstructured::dim::horizontal") if self.entity.has_vertical_dimension: tags.append("unstructured::dim::vertical") @@ -46,7 +50,9 @@ def render_sid(self) -> str: raise BindingsRenderingException("can not render sid of a scalar") # We want to compute the rank without the sparse dimension, i.e. if a field is horizontal, vertical or both. - dense_rank = self.entity.rank() - int(self.entity.is_sparse() or self.entity.is_compound()) + dense_rank = self.entity.rank() - int( + self.entity.is_sparse() or self.entity.is_compound() + ) if dense_rank == 1: values_str = "1" elif self.entity.is_compound(): diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py index aa40dcf2ad..c0a1d17811 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py @@ -77,7 +77,9 @@ def __init__(self, offsets: Sequence[OffsetEntity]): def make_table_vars(self) -> list[str]: if not self.has_offsets: return [] - unique_offsets = sorted({self._make_table_var(offset) for offset in self.offsets}) + unique_offsets = sorted( + {self._make_table_var(offset) for offset in self.offsets} + ) return list(unique_offsets) def make_neighbor_tables(self) -> list[str]: @@ -101,4 +103,7 @@ def _make_table_var(offset: OffsetEntity) -> str: @staticmethod def _make_location_type(offset: OffsetEntity) -> list[str]: - return [f"LocationType::{loc.render_location_type()}" for loc in offset.target[1].chain] + return [ + f"LocationType::{loc.render_location_type()}" + for loc in offset.target[1].chain + ] diff --git a/tools/src/icon4pytools/icon4pygen/bindings/entities.py b/tools/src/icon4pytools/icon4pygen/bindings/entities.py index cad08c7c0d..44dc97c1f1 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/entities.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/entities.py @@ -19,7 +19,11 @@ from icon4pytools.icon4pygen.bindings.codegen.render.field import FieldRenderer from icon4pytools.icon4pygen.bindings.codegen.render.offset import OffsetRenderer -from icon4pytools.icon4pygen.bindings.codegen.types import FieldEntity, FieldIntent, OffsetEntity +from icon4pytools.icon4pygen.bindings.codegen.types import ( + FieldEntity, + FieldIntent, + OffsetEntity, +) from icon4pytools.icon4pygen.bindings.exceptions import BindingsTypeConsistencyException from icon4pytools.icon4pygen.bindings.locations import ( BASIC_LOCATIONS, @@ -71,7 +75,9 @@ def _handle_source(chain: str) -> Union[BasicLocation, CompoundLocation]: if source in [str(loc()) for loc in BASIC_LOCATIONS.values()]: return chain_from_str(source)[0] - elif all(char in [str(loc()) for loc in BASIC_LOCATIONS.values()] for char in source): + elif all( + char in [str(loc()) for loc in BASIC_LOCATIONS.values()] for char in source + ): return CompoundLocation(chain_from_str(source)) else: raise BindingsTypeConsistencyException(f"Invalid source {source}") @@ -156,7 +162,9 @@ def _update_horizontal_location(self, field: past.DataSymbol) -> None: if not isinstance(field.type, ts.FieldType): return - maybe_horizontal_dimension = list(filter(lambda dim: dim.value != "K", field.type.dims)) + maybe_horizontal_dimension = list( + filter(lambda dim: dim.value != "K", field.type.dims) + ) # early abort if field is vertical if not len(maybe_horizontal_dimension): diff --git a/tools/src/icon4pytools/icon4pygen/bindings/locations.py b/tools/src/icon4pytools/icon4pygen/bindings/locations.py index 
4071da48b6..35a825ddf7 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/locations.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/locations.py @@ -64,7 +64,9 @@ def __init__(self, chain: list[BasicLocation]) -> None: if is_valid(chain): self.chain = chain else: - raise Exception(f"chain {chain} contains two of the same elements in succession") + raise Exception( + f"chain {chain} contains two of the same elements in succession" + ) def __iter__(self) -> Iterator[BasicLocation]: return iter(self.chain) diff --git a/tools/src/icon4pytools/icon4pygen/bindings/workflow.py b/tools/src/icon4pytools/icon4pygen/bindings/workflow.py index 83a5d1f6ee..108f602fd6 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/workflow.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/workflow.py @@ -34,7 +34,9 @@ class PyBindGen: Furthermore, we also serialise data to .csv or .vtk files in case of verification failure. """ - def __init__(self, stencil_info: StencilInfo, levels_per_thread: int, block_size: int) -> None: + def __init__( + self, stencil_info: StencilInfo, levels_per_thread: int, block_size: int + ) -> None: self.stencil_name = stencil_info.itir.id self.fields, self.offsets = self._stencil_info_to_binding_type(stencil_info) self.levels_per_thread = levels_per_thread diff --git a/tools/src/icon4pytools/icon4pygen/cli.py b/tools/src/icon4pytools/icon4pygen/cli.py index 66180adf44..0a1f0f0233 100644 --- a/tools/src/icon4pytools/icon4pygen/cli.py +++ b/tools/src/icon4pytools/icon4pygen/cli.py @@ -45,7 +45,9 @@ def shell_complete(self, ctx, param, incomplete): @click.argument("fencil", type=ModuleType()) @click.argument("block_size", type=int, default=128) @click.argument("levels_per_thread", type=int, default=4) -@click.option("--is_global", is_flag=True, type=bool, help="Whether this is a global run.") +@click.option( + "--is_global", is_flag=True, type=bool, help="Whether this is a global run." +) @click.argument( "outpath", type=click.Path(dir_okay=True, resolve_path=True, path_type=pathlib.Path), diff --git a/tools/src/icon4pytools/icon4pygen/metadata.py b/tools/src/icon4pytools/icon4pygen/metadata.py index 1f1b42e4cb..0b4a86e02f 100644 --- a/tools/src/icon4pytools/icon4pygen/metadata.py +++ b/tools/src/icon4pytools/icon4pygen/metadata.py @@ -142,7 +142,9 @@ def get_fvprog(fencil_def: Program | Any) -> Program: return fvprog -def provide_offset(offset: str, is_global: bool = False) -> DummyConnectivity | Dimension: +def provide_offset( + offset: str, is_global: bool = False +) -> DummyConnectivity | Dimension: if offset == Koff.value: assert len(Koff.target) == 1 assert Koff.source == Koff.target[0] @@ -164,7 +166,9 @@ def provide_neighbor_table(chain: str, is_global: bool) -> DummyConnectivity: and pass the tokens after to the algorithm below """ # note: this seems really brittle. maybe agree on a keyword to indicate new sparse fields? 
- new_sparse_field = any(len(token) > 1 for token in chain.split("2")) and not chain.endswith("O") + new_sparse_field = any( + len(token) > 1 for token in chain.split("2") + ) and not chain.endswith("O") if new_sparse_field: chain = chain.split("2")[1] skip_values = False @@ -195,9 +199,13 @@ def provide_neighbor_table(chain: str, is_global: bool) -> DummyConnectivity: def scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]: """Scan PAST node for offsets and return a set of all offsets.""" - all_types = fvprog.past_node.pre_walk_values().if_isinstance(past.Symbol).getattr("type") + all_types = ( + fvprog.past_node.pre_walk_values().if_isinstance(past.Symbol).getattr("type") + ) all_field_types = [ - symbol_type for symbol_type in all_types if isinstance(symbol_type, ts.FieldType) + symbol_type + for symbol_type in all_types + if isinstance(symbol_type, ts.FieldType) ] all_dims = set(i for j in all_field_types for i in j.dims) all_offset_labels = ( diff --git a/tools/src/icon4pytools/liskov/cli.py b/tools/src/icon4pytools/liskov/cli.py index a16e0a428e..806bcdcec8 100644 --- a/tools/src/icon4pytools/liskov/cli.py +++ b/tools/src/icon4pytools/liskov/cli.py @@ -52,7 +52,9 @@ def main(ctx): ) @click.argument( "input_path", - type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), + type=click.Path( + exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path + ), ) @click.argument( "output_path", @@ -82,7 +84,9 @@ def integrate(input_path, output_path, profile, metadatagen): ) @click.argument( "input_path", - type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), + type=click.Path( + exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path + ), ) @click.argument( "output_path", diff --git a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py index 1884190504..2f4d6c23da 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py @@ -115,7 +115,9 @@ def __call__(self, parsed: ts.ParsedDict) -> CodeGenInput: class EndCreateDataFactory(OptionalMultiUseDataFactory): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndCreate + directive_cls: Type[ + ts.ParsedDirective + ] = icon4pytools.liskov.parsing.parse.EndCreate dtype: Type[EndCreateData] = EndCreateData @@ -130,16 +132,19 @@ class EndIfDataFactory(OptionalMultiUseDataFactory): class EndProfileDataFactory(OptionalMultiUseDataFactory): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndProfile + directive_cls: Type[ + ts.ParsedDirective + ] = icon4pytools.liskov.parsing.parse.EndProfile dtype: Type[EndProfileData] = EndProfileData class StartCreateDataFactory(DataFactoryBase): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartCreate + directive_cls: Type[ + ts.ParsedDirective + ] = icon4pytools.liskov.parsing.parse.StartCreate dtype: Type[StartCreateData] = StartCreateData def __call__(self, parsed: ts.ParsedDict) -> list[StartCreateData]: - deserialised = [] extracted = extract_directive(parsed["directives"], self.directive_cls) @@ -147,7 +152,6 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartCreateData]: return UnusedDirective for i, directive in enumerate(extracted): - named_args = parsed["content"]["StartCreate"][i] extra_fields = None @@ -177,7 +181,9 @@ def 
__call__(self, parsed: ts.ParsedDict) -> list[DeclareData]: extracted = extract_directive(parsed["directives"], self.directive_cls) for i, directive in enumerate(extracted): named_args = parsed["content"]["Declare"][i] - ident_type = pop_item_from_dict(named_args, "type", DEFAULT_DECLARE_IDENT_TYPE) + ident_type = pop_item_from_dict( + named_args, "type", DEFAULT_DECLARE_IDENT_TYPE + ) suffix = pop_item_from_dict(named_args, "suffix", DEFAULT_DECLARE_SUFFIX) deserialised.append( self.dtype( @@ -191,7 +197,9 @@ def __call__(self, parsed: ts.ParsedDict) -> list[DeclareData]: class StartProfileDataFactory(DataFactoryBase): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartProfile + directive_cls: Type[ + ts.ParsedDirective + ] = icon4pytools.liskov.parsing.parse.StartProfile dtype: Type[StartProfileData] = StartProfileData def __call__(self, parsed: ts.ParsedDict) -> list[StartProfileData]: @@ -200,12 +208,16 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartProfileData]: for i, directive in enumerate(extracted): named_args = parsed["content"]["StartProfile"][i] stencil_name = _extract_stencil_name(named_args, directive) - deserialised.append(self.dtype(name=stencil_name, startln=directive.startln)) + deserialised.append( + self.dtype(name=stencil_name, startln=directive.startln) + ) return deserialised class EndStencilDataFactory(DataFactoryBase): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndStencil + directive_cls: Type[ + ts.ParsedDirective + ] = icon4pytools.liskov.parsing.parse.EndStencil dtype: Type[EndStencilData] = EndStencilData def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]: @@ -230,7 +242,9 @@ def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]: class StartStencilDataFactory(DataFactoryBase): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil + directive_cls: Type[ + ts.ParsedDirective + ] = icon4pytools.liskov.parsing.parse.StartStencil dtype: Type[StartStencilData] = StartStencilData def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]: @@ -244,13 +258,20 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]: """ deserialised = [] field_dimensions = flatten_list_of_dicts( - [DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]] + [ + DeclareDataFactory.get_field_dimensions(dim) + for dim in parsed["content"]["Declare"] + ] ) directives = extract_directive(parsed["directives"], self.directive_cls) for i, directive in enumerate(directives): named_args = parsed["content"]["StartStencil"][i] - acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true")) - mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false")) + acc_present = string_to_bool( + pop_item_from_dict(named_args, "accpresent", "true") + ) + mergecopy = string_to_bool( + pop_item_from_dict(named_args, "mergecopy", "false") + ) copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true")) stencil_name = _extract_stencil_name(named_args, directive) bounds = self._make_bounds(named_args) diff --git a/tools/src/icon4pytools/liskov/codegen/integration/template.py b/tools/src/icon4pytools/liskov/codegen/integration/template.py index dba9543864..9b313d21d3 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/template.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/template.py @@ -20,7 +20,10 @@ from gt4py.eve.codegen import TemplatedGenerator 
from icon4pytools.liskov.codegen.integration.exceptions import UndeclaredFieldError -from icon4pytools.liskov.codegen.integration.interface import DeclareData, StartStencilData +from icon4pytools.liskov.codegen.integration.interface import ( + DeclareData, + StartStencilData, +) from icon4pytools.liskov.external.metadata import CodeMetadata @@ -238,7 +241,9 @@ class StartStencilStatement(eve.Node): def __post_init__(self) -> None: # type: ignore all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] - self.copy_declarations = [self.make_copy_declaration(f) for f in all_fields if f.out] + self.copy_declarations = [ + self.make_copy_declaration(f) for f in all_fields if f.out + ] self.acc_present = "PRESENT" if self.stencil_data.acc_present else "NONE" @staticmethod diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py index 9db5533201..d64adfc707 100644 --- a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py @@ -178,7 +178,9 @@ def _get_timestep_variables(stencil_name: str) -> dict: if "mo_icon_interpolation_scalar" in stencil_name: timestep_variables["jstep"] = "jstep_ptr" - timestep_variables["mo_icon_interpolation_ctr"] = "mo_icon_interpolation_ctr" + timestep_variables[ + "mo_icon_interpolation_ctr" + ] = "mo_icon_interpolation_ctr" if "mo_advection_traj" in stencil_name: timestep_variables["jstep"] = "jstep_ptr" diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py b/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py index fed5b8f735..0b2c343596 100644 --- a/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py +++ b/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py @@ -13,7 +13,9 @@ from typing import Any from icon4pytools.common.logger import setup_logger -from icon4pytools.liskov.codegen.serialisation.interface import SerialisationCodeInterface +from icon4pytools.liskov.codegen.serialisation.interface import ( + SerialisationCodeInterface, +) from icon4pytools.liskov.codegen.serialisation.template import ( ImportStatement, ImportStatementGenerator, diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/template.py b/tools/src/icon4pytools/liskov/codegen/serialisation/template.py index 8c9f8d5325..fe9d630ebd 100644 --- a/tools/src/icon4pytools/liskov/codegen/serialisation/template.py +++ b/tools/src/icon4pytools/liskov/codegen/serialisation/template.py @@ -49,11 +49,15 @@ class SavepointStatement(eve.Node): multinode: bool standard_fields: StandardFields = eve.datamodels.field(init=False) decomposed_fields: DecomposedFields = eve.datamodels.field(init=False) - decomposed_field_declarations: DecomposedFieldDeclarations = eve.datamodels.field(init=False) + decomposed_field_declarations: DecomposedFieldDeclarations = eve.datamodels.field( + init=False + ) def __post_init__(self): self.standard_fields = StandardFields( - fields=[Field(**asdict(f)) for f in self.savepoint.fields if not f.decomposed] + fields=[ + Field(**asdict(f)) for f in self.savepoint.fields if not f.decomposed + ] ) self.decomposed_fields = DecomposedFields( fields=[Field(**asdict(f)) for f in self.savepoint.fields if f.decomposed] diff --git a/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py b/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py index e71ef9fe3f..302b0d762f 100644 --- 
a/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py @@ -16,7 +16,9 @@ import icon4pytools.liskov.parsing.types as ts from icon4pytools.common.logger import setup_logger from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface -from icon4pytools.liskov.codegen.serialisation.interface import SerialisationCodeInterface +from icon4pytools.liskov.codegen.serialisation.interface import ( + SerialisationCodeInterface, +) from icon4pytools.liskov.pipeline.definition import Step diff --git a/tools/src/icon4pytools/liskov/external/gt4py.py b/tools/src/icon4pytools/liskov/external/gt4py.py index 4ad14ef9d3..126abe2175 100644 --- a/tools/src/icon4pytools/liskov/external/gt4py.py +++ b/tools/src/icon4pytools/liskov/external/gt4py.py @@ -20,7 +20,10 @@ from icon4pytools.common.logger import setup_logger from icon4pytools.icon4pygen.metadata import get_stencil_info from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface -from icon4pytools.liskov.external.exceptions import IncompatibleFieldError, UnknownStencilError +from icon4pytools.liskov.external.exceptions import ( + IncompatibleFieldError, + UnknownStencilError, +) from icon4pytools.liskov.pipeline.definition import Step @@ -62,7 +65,9 @@ def _collect_icon4py_stencil(self, stencil_name: str) -> Program: err_counter += 1 if err_counter == len(self._STENCIL_PACKAGES): - raise UnknownStencilError(f"Did not find module: {stencil_name} in icon4pytools.") + raise UnknownStencilError( + f"Did not find module: {stencil_name} in icon4pytools." + ) module_members = getmembers(module) found_stencil = [elt for elt in module_members if elt[0] == stencil_name] diff --git a/tools/src/icon4pytools/liskov/parsing/parse.py b/tools/src/icon4pytools/liskov/parsing/parse.py index 02759621a5..984dae6ca0 100644 --- a/tools/src/icon4pytools/liskov/parsing/parse.py +++ b/tools/src/icon4pytools/liskov/parsing/parse.py @@ -81,14 +81,18 @@ def _determine_type( ) return typed - def _preprocess(self, directives: Sequence[ts.ParsedDirective]) -> Sequence[ts.ParsedDirective]: + def _preprocess( + self, directives: Sequence[ts.ParsedDirective] + ) -> Sequence[ts.ParsedDirective]: """Preprocess the directives by removing unnecessary characters and formatting the directive strings.""" return [ d.__class__(self._clean_string(d.string), d.startln, d.endln) # type: ignore for d in directives ] - def _run_validation_passes(self, preprocessed: Sequence[ts.ParsedDirective]) -> None: + def _run_validation_passes( + self, preprocessed: Sequence[ts.ParsedDirective] + ) -> None: """Run validation passes on the directives.""" for validator in VALIDATORS: validator(self.input_filepath).validate(preprocessed) diff --git a/tools/src/icon4pytools/liskov/parsing/validation.py b/tools/src/icon4pytools/liskov/parsing/validation.py index f05a97e720..c27361b8b8 100644 --- a/tools/src/icon4pytools/liskov/parsing/validation.py +++ b/tools/src/icon4pytools/liskov/parsing/validation.py @@ -25,7 +25,10 @@ RequiredDirectivesError, UnbalancedStencilDirectiveError, ) -from icon4pytools.liskov.parsing.utils import print_parsed_directive, remove_directive_types +from icon4pytools.liskov.parsing.utils import ( + print_parsed_directive, + remove_directive_types, +) logger = setup_logger(__name__) @@ -65,12 +68,16 @@ def validate(self, directives: list[ts.ParsedDirective]) -> None: self._validate_outer(d.string, d.pattern, d) self._validate_inner(d.string, d.pattern, d) - def 
_validate_outer(self, to_validate: str, pattern: str, d: ts.ParsedDirective) -> None: + def _validate_outer( + self, to_validate: str, pattern: str, d: ts.ParsedDirective + ) -> None: regex = f"{pattern}\\((.*)\\)" match = re.fullmatch(regex, to_validate) self.exception_handler.check_for_matches(d, match, regex, self.filepath) - def _validate_inner(self, to_validate: str, pattern: str, d: ts.ParsedDirective) -> None: + def _validate_inner( + self, to_validate: str, pattern: str, d: ts.ParsedDirective + ) -> None: inner = to_validate.replace(f"{pattern}", "")[1:-1].split(";") for arg in inner: match = re.fullmatch(d.regex, arg) @@ -101,7 +108,9 @@ def validate(self, directives: list[ts.ParsedDirective]) -> None: self._validate_required_directives(directives) self._validate_stencil_directives(directives) - def _validate_directive_uniqueness(self, directives: list[ts.ParsedDirective]) -> None: + def _validate_directive_uniqueness( + self, directives: list[ts.ParsedDirective] + ) -> None: """Check that all used directives are unique. Note: Allow repeated START STENCIL, END STENCIL and ENDIF directives. @@ -125,7 +134,9 @@ def _validate_directive_uniqueness(self, directives: list[ts.ParsedDirective]) - f"Error in {self.filepath}.\n Found same directive more than once in the following directives:\n {pretty_printed}" ) - def _validate_required_directives(self, directives: list[ts.ParsedDirective]) -> None: + def _validate_required_directives( + self, directives: list[ts.ParsedDirective] + ) -> None: """Check that all required directives are used at least once.""" expected = [ icon4pytools.liskov.parsing.parse.Declare, @@ -145,9 +156,13 @@ def extract_arg_from_directive(directive: str, arg: str) -> str: if match: return match.group(1) else: - raise ValueError(f"Invalid directive string, could not find '{arg}' parameter.") + raise ValueError( + f"Invalid directive string, could not find '{arg}' parameter." + ) - def _validate_stencil_directives(self, directives: list[ts.ParsedDirective]) -> None: + def _validate_stencil_directives( + self, directives: list[ts.ParsedDirective] + ) -> None: """Validate that the number of start and end stencil directives match in the input `directives`. Also verifies that each unique stencil has a corresponding start and end directive. @@ -171,10 +186,14 @@ def _validate_stencil_directives(self, directives: list[ts.ParsedDirective]) -> for directive in stencil_directives: stencil_name = self.extract_arg_from_directive(directive.string, "name") stencil_counts[stencil_name] = stencil_counts.get(stencil_name, 0) + ( - 1 if isinstance(directive, icon4pytools.liskov.parsing.parse.StartStencil) else -1 + 1 + if isinstance(directive, icon4pytools.liskov.parsing.parse.StartStencil) + else -1 ) - unbalanced_stencils = [stencil for stencil, count in stencil_counts.items() if count != 0] + unbalanced_stencils = [ + stencil for stencil, count in stencil_counts.items() if count != 0 + ] if unbalanced_stencils: raise UnbalancedStencilDirectiveError( f"Error in {self.filepath}. Each unique stencil must have a corresponding START STENCIL and END STENCIL directive." 
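The re-wrapped _validate_stencil_directives hunk above is easy to misread in diff form: the balance check increments a per-stencil counter on START STENCIL and decrements it on END STENCIL, so any non-zero count marks an unbalanced stencil. A minimal standalone sketch of that counting idea, using a simplified Directive stand-in for ts.ParsedDirective (the names below are illustrative, not part of the patch):

    from dataclasses import dataclass

    @dataclass
    class Directive:
        kind: str  # "start" or "end"; stands in for the StartStencil / EndStencil classes
        name: str  # stencil name extracted from the directive string

    def unbalanced_stencils(directives: list[Directive]) -> list[str]:
        counts: dict[str, int] = {}
        for d in directives:
            counts[d.name] = counts.get(d.name, 0) + (1 if d.kind == "start" else -1)
        # A non-zero count means a START without a matching END, or vice versa.
        return sorted(name for name, count in counts.items() if count != 0)

    assert unbalanced_stencils(
        [Directive("start", "foo"), Directive("end", "foo"), Directive("start", "bar")]
    ) == ["bar"]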
diff --git a/tools/src/icon4pytools/liskov/pipeline/collection.py b/tools/src/icon4pytools/liskov/pipeline/collection.py index 006fa8a0e9..b71c2d2d4f 100644 --- a/tools/src/icon4pytools/liskov/pipeline/collection.py +++ b/tools/src/icon4pytools/liskov/pipeline/collection.py @@ -12,11 +12,17 @@ # SPDX-License-Identifier: GPL-3.0-or-later from pathlib import Path -from icon4pytools.liskov.codegen.integration.deserialise import IntegrationCodeDeserialiser +from icon4pytools.liskov.codegen.integration.deserialise import ( + IntegrationCodeDeserialiser, +) from icon4pytools.liskov.codegen.integration.generate import IntegrationCodeGenerator from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface -from icon4pytools.liskov.codegen.serialisation.deserialise import SerialisationCodeDeserialiser -from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator +from icon4pytools.liskov.codegen.serialisation.deserialise import ( + SerialisationCodeDeserialiser, +) +from icon4pytools.liskov.codegen.serialisation.generate import ( + SerialisationCodeGenerator, +) from icon4pytools.liskov.codegen.shared.write import CodegenWriter from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils from icon4pytools.liskov.parsing.parse import DirectivesParser diff --git a/tools/tests/f2ser/test_f2ser_codegen.py b/tools/tests/f2ser/test_f2ser_codegen.py index 5ca607eff9..8beedb6432 100644 --- a/tools/tests/f2ser/test_f2ser_codegen.py +++ b/tools/tests/f2ser/test_f2ser_codegen.py @@ -15,7 +15,9 @@ from icon4pytools.f2ser.deserialise import ParsedGranuleDeserialiser from icon4pytools.f2ser.parse import GranuleParser -from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator +from icon4pytools.liskov.codegen.serialisation.generate import ( + SerialisationCodeGenerator, +) from icon4pytools.liskov.codegen.shared.types import GeneratedCode @@ -102,5 +104,7 @@ def test_deserialiser_directives_diffusion_codegen( parsed = GranuleParser(diffusion_granule, diffusion_granule_deps)() interface = ParsedGranuleDeserialiser(parsed)() generated = SerialisationCodeGenerator(interface)() - reference_savepoint = (samples_path / "expected_diffusion_granule_savepoint.f90").read_text() + reference_savepoint = ( + samples_path / "expected_diffusion_granule_savepoint.f90" + ).read_text() assert generated[0].source == reference_savepoint.rstrip() diff --git a/tools/tests/f2ser/test_granule_deserialiser.py b/tools/tests/f2ser/test_granule_deserialiser.py index d0f478e7ce..0fa352b8fa 100644 --- a/tools/tests/f2ser/test_granule_deserialiser.py +++ b/tools/tests/f2ser/test_granule_deserialiser.py @@ -83,7 +83,11 @@ def test_deserialiser_mock(mock_parsed_granule): assert len(interface.Savepoint) == 3 assert all([isinstance(s, SavepointData) for s in interface.Savepoint]) assert all( - [isinstance(f, FieldSerialisationData) for s in interface.Savepoint for f in s.fields] + [ + isinstance(f, FieldSerialisationData) + for s in interface.Savepoint + for f in s.fields + ] ) diff --git a/tools/tests/f2ser/test_parsing.py b/tools/tests/f2ser/test_parsing.py index b0346e72f1..e2a598c07a 100644 --- a/tools/tests/f2ser/test_parsing.py +++ b/tools/tests/f2ser/test_parsing.py @@ -51,7 +51,9 @@ def test_granule_parsing(diffusion_granule, diffusion_granule_deps): def test_granule_parsing_missing_derived_typedef(diffusion_granule, samples_path): dependencies = [samples_path / "subroutine_example.f90"] parser = GranuleParser(diffusion_granule, 
dependencies) - with pytest.raises(MissingDerivedTypeError, match="Could not find type definition for TYPE"): + with pytest.raises( + MissingDerivedTypeError, match="Could not find type definition for TYPE" + ): parser() diff --git a/tools/tests/icon4pygen/test_backend.py b/tools/tests/icon4pygen/test_backend.py index e712d3ffa6..b0fbfd9b00 100644 --- a/tools/tests/icon4pygen/test_backend.py +++ b/tools/tests/icon4pygen/test_backend.py @@ -28,6 +28,8 @@ ) def test_missing_domain_args(input_params, expected_complement): params = [itir.Sym(id=p) for p in input_params] - domain_boundaries = set(map(lambda s: str(s.id), GTHeader._missing_domain_params(params))) + domain_boundaries = set( + map(lambda s: str(s.id), GTHeader._missing_domain_params(params)) + ) assert len(domain_boundaries) == len(expected_complement) assert domain_boundaries == set(expected_complement) diff --git a/tools/tests/icon4pygen/test_field_rendering.py b/tools/tests/icon4pygen/test_field_rendering.py index 2ace9a6e24..68231f57f4 100644 --- a/tools/tests/icon4pygen/test_field_rendering.py +++ b/tools/tests/icon4pygen/test_field_rendering.py @@ -61,7 +61,9 @@ def identity(field: Field[[EdgeDim, KDim], float]) -> Field[[EdgeDim, KDim], flo return field @program - def identity_prog(field: Field[[EdgeDim, KDim], float], out: Field[[EdgeDim, KDim], float]): + def identity_prog( + field: Field[[EdgeDim, KDim], float], out: Field[[EdgeDim, KDim], float] + ): identity(field, out=out) stencil_info = get_stencil_info(identity_prog) @@ -75,7 +77,9 @@ def identity_prog(field: Field[[EdgeDim, KDim], float], out: Field[[EdgeDim, KDi def test_vertical_sparse_field_sid_rendering(): @field_operator - def reduction(nb_field: Field[[EdgeDim, E2CDim, KDim], float]) -> Field[[EdgeDim, KDim], float]: + def reduction( + nb_field: Field[[EdgeDim, E2CDim, KDim], float] + ) -> Field[[EdgeDim, KDim], float]: return neighbor_sum(nb_field, axis=E2CDim) @program diff --git a/tools/tests/liskov/test_directives_deserialiser.py b/tools/tests/liskov/test_directives_deserialiser.py index 13a235f6f5..b1de1b72ba 100644 --- a/tools/tests/liskov/test_directives_deserialiser.py +++ b/tools/tests/liskov/test_directives_deserialiser.py @@ -85,7 +85,9 @@ ), ], ) -def test_data_factories_no_args(factory_class, directive_type, string, startln, endln, expected): +def test_data_factories_no_args( + factory_class, directive_type, string, startln, endln, expected +): parsed = { "directives": [directive_type(string=string, startln=startln, endln=endln)], "content": {}, @@ -109,7 +111,9 @@ def test_data_factories_no_args(factory_class, directive_type, string, startln, { "directives": [ ts.EndStencil("END STENCIL(name=foo)", 5, 5), - ts.EndStencil("END STENCIL(name=bar; noendif=true; noprofile=true)", 20, 20), + ts.EndStencil( + "END STENCIL(name=bar; noendif=true; noprofile=true)", 20, 20 + ), ], "content": { "EndStencil": [ @@ -123,7 +127,9 @@ def test_data_factories_no_args(factory_class, directive_type, string, startln, EndStencilDataFactory, EndStencilData, { - "directives": [ts.EndStencil("END STENCIL(name=foo; noprofile=true)", 5, 5)], + "directives": [ + ts.EndStencil("END STENCIL(name=foo; noprofile=true)", 5, 5) + ], "content": {"EndStencil": [{"name": "foo"}]}, }, ), @@ -197,7 +203,9 @@ def test_data_factories_with_args(factory, target, mock_data): ), ( { - "directives": [ts.StartCreate("START CREATE(extra_fields=foo,xyz)", 5, 5)], + "directives": [ + ts.StartCreate("START CREATE(extra_fields=foo,xyz)", 5, 5) + ], "content": {"StartCreate": [{"extra_fields": 
"foo,xyz"}]}, }, ["foo", "xyz"], @@ -231,7 +239,9 @@ def test_start_create_factory(mock_data, extra_fields): ts.EndStencil("END STENCIL(name=foo)", 5, 5), ts.EndStencil("END STENCIL(name=bar; noendif=foo)", 20, 20), ], - "content": {"EndStencil": [{"name": "foo"}, {"name": "bar", "noendif": "foo"}]}, + "content": { + "EndStencil": [{"name": "foo"}, {"name": "bar", "noendif": "foo"}] + }, }, ), ], @@ -314,7 +324,10 @@ def test_update_field_tolerances(self): FieldAssociationData("x", "i", 3, rel_tol="0.01", abs_tol="0.1"), FieldAssociationData("y", "i", 3, rel_tol="0.001"), ] - assert self.factory._update_tolerances(named_args, self.mock_fields) == expected_fields + assert ( + self.factory._update_tolerances(named_args, self.mock_fields) + == expected_fields + ) def test_update_field_tolerances_not_all_fields(self): # Test that tolerance is not set for fields that are not provided in the named_args. @@ -326,9 +339,15 @@ def test_update_field_tolerances_not_all_fields(self): FieldAssociationData("x", "i", 3, rel_tol="0.01", abs_tol="0.1"), FieldAssociationData("y", "i", 3), ] - assert self.factory._update_tolerances(named_args, self.mock_fields) == expected_fields + assert ( + self.factory._update_tolerances(named_args, self.mock_fields) + == expected_fields + ) def test_update_field_tolerances_no_tolerances(self): # Test that fields are not updated if named_args does not contain any tolerances. named_args = {} - assert self.factory._update_tolerances(named_args, self.mock_fields) == self.mock_fields + assert ( + self.factory._update_tolerances(named_args, self.mock_fields) + == self.mock_fields + ) diff --git a/tools/tests/liskov/test_external.py b/tools/tests/liskov/test_external.py index 1e0de1b5ca..e912a09932 100644 --- a/tools/tests/liskov/test_external.py +++ b/tools/tests/liskov/test_external.py @@ -22,7 +22,10 @@ IntegrationCodeInterface, StartStencilData, ) -from icon4pytools.liskov.external.exceptions import IncompatibleFieldError, UnknownStencilError +from icon4pytools.liskov.external.exceptions import ( + IncompatibleFieldError, + UnknownStencilError, +) from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils diff --git a/tools/tests/liskov/test_generation.py b/tools/tests/liskov/test_generation.py index dc16b09c6a..03bd4d5d49 100644 --- a/tools/tests/liskov/test_generation.py +++ b/tools/tests/liskov/test_generation.py @@ -31,7 +31,9 @@ ) # TODO: fix tests to adapt to new custom output fields -from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator +from icon4pytools.liskov.codegen.serialisation.generate import ( + SerialisationCodeGenerator, +) from icon4pytools.liskov.codegen.serialisation.interface import ( FieldSerialisationData, ImportData, @@ -49,7 +51,9 @@ def integration_code_interface(): fields=[ FieldAssociationData("scalar1", "scalar1", inp=True, out=False, dims=None), FieldAssociationData("inp1", "inp1(:,:,1)", inp=True, out=False, dims=2), - FieldAssociationData("out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5"), + FieldAssociationData( + "out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5" + ), FieldAssociationData( "out2", "p_nh%prog(nnew)%out2(:,:,1)", @@ -58,8 +62,12 @@ def integration_code_interface(): dims=3, abs_tol="0.2", ), - FieldAssociationData("out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2), - FieldAssociationData("out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3), + FieldAssociationData( + "out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, 
out=True, dims=2 + ), + FieldAssociationData( + "out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3 + ), FieldAssociationData( "out5", "p_nh%prog(nnew)%w(:,:,:,ntnd)", inp=False, out=True, dims=3 ), @@ -206,7 +214,9 @@ def expected_insert_source(): @pytest.fixture def integration_code_generator(integration_code_interface): - return IntegrationCodeGenerator(integration_code_interface, profile=True, metadatagen=False) + return IntegrationCodeGenerator( + integration_code_interface, profile=True, metadatagen=False + ) def test_integration_code_generation( @@ -327,7 +337,9 @@ def expected_savepoints(): def test_serialisation_code_generation( serialisation_code_interface, expected_savepoints, multinode ): - generated = SerialisationCodeGenerator(serialisation_code_interface, multinode=multinode)() + generated = SerialisationCodeGenerator( + serialisation_code_interface, multinode=multinode + )() if multinode: assert len(generated) == 3 diff --git a/tools/tests/liskov/test_serialisation_deserialiser.py b/tools/tests/liskov/test_serialisation_deserialiser.py index 0431086beb..380b8c14b9 100644 --- a/tools/tests/liskov/test_serialisation_deserialiser.py +++ b/tools/tests/liskov/test_serialisation_deserialiser.py @@ -109,7 +109,9 @@ def parsed_dict(): ], "StartProfile": [{"name": "apply_nabla2_to_vn_in_lateral_boundary"}], "EndProfile": [{}], - "EndStencil": [{"name": "apply_nabla2_to_vn_in_lateral_boundary", "noprofile": "True"}], + "EndStencil": [ + {"name": "apply_nabla2_to_vn_in_lateral_boundary", "noprofile": "True"} + ], "EndCreate": [{}], }, } diff --git a/tools/tests/liskov/test_validation.py b/tools/tests/liskov/test_validation.py index d6fbc4f927..e02c08c46c 100644 --- a/tools/tests/liskov/test_validation.py +++ b/tools/tests/liskov/test_validation.py @@ -20,7 +20,12 @@ RequiredDirectivesError, UnbalancedStencilDirectiveError, ) -from icon4pytools.liskov.parsing.parse import Declare, DirectivesParser, Imports, StartStencil +from icon4pytools.liskov.parsing.parse import ( + Declare, + DirectivesParser, + Imports, + StartStencil, +) from icon4pytools.liskov.parsing.validation import DirectiveSyntaxValidator from .conftest import insert_new_lines, scan_for_directives @@ -72,7 +77,9 @@ def test_directive_syntax_validator(directive): "!$DSL IMPORTS()", ], ) -def test_directive_semantics_validation_repeated_directives(make_f90_tmpfile, directive): +def test_directive_semantics_validation_repeated_directives( + make_f90_tmpfile, directive +): fpath = make_f90_tmpfile(content=SINGLE_STENCIL) opath = fpath.with_suffix(".gen") insert_new_lines(fpath, [directive]) @@ -108,7 +115,9 @@ def test_directive_semantics_validation_repeated_stencil(make_f90_tmpfile, direc """!$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; noprofile=True)""", ], ) -def test_directive_semantics_validation_required_directives(make_f90_tmpfile, directive): +def test_directive_semantics_validation_required_directives( + make_f90_tmpfile, directive +): new = SINGLE_STENCIL.replace(directive, "") fpath = make_f90_tmpfile(content=new) opath = fpath.with_suffix(".gen") diff --git a/tools/tests/liskov/test_writer.py b/tools/tests/liskov/test_writer.py index e24410120e..ed8cb512c0 100644 --- a/tools/tests/liskov/test_writer.py +++ b/tools/tests/liskov/test_writer.py @@ -64,7 +64,9 @@ def test_insert_generated_code(): "another line", "generated code2\n", ] - assert CodegenWriter._insert_generated_code(current_file, generated) == expected_output + assert ( + CodegenWriter._insert_generated_code(current_file, 
generated) == expected_output + ) def test_write_file(): From 37d33258817ad1488d4dd832e25da9bc06e33ea0 Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Tue, 11 Jul 2023 14:53:56 +0200 Subject: [PATCH 4/4] Run precommit --- tools/src/icon4pytools/common/logger.py | 4 +- tools/src/icon4pytools/f2ser/cli.py | 8 +-- tools/src/icon4pytools/f2ser/deserialise.py | 12 ++--- tools/src/icon4pytools/f2ser/parse.py | 51 +++++++------------ .../icon4pygen/bindings/codegen/cpp.py | 16 ++---- .../icon4pygen/bindings/codegen/f90.py | 50 +++++------------- .../bindings/codegen/render/field.py | 10 +--- .../bindings/codegen/render/offset.py | 9 +--- .../icon4pygen/bindings/entities.py | 14 ++--- .../icon4pygen/bindings/locations.py | 4 +- .../icon4pygen/bindings/workflow.py | 4 +- tools/src/icon4pytools/icon4pygen/cli.py | 4 +- tools/src/icon4pytools/icon4pygen/metadata.py | 16 ++---- tools/src/icon4pytools/liskov/cli.py | 8 +-- .../liskov/codegen/integration/deserialise.py | 45 ++++------------ .../liskov/codegen/integration/template.py | 9 +--- .../codegen/serialisation/deserialise.py | 4 +- .../liskov/codegen/serialisation/generate.py | 4 +- .../liskov/codegen/serialisation/template.py | 8 +-- .../liskov/codegen/shared/deserialise.py | 4 +- .../src/icon4pytools/liskov/external/gt4py.py | 9 +--- .../src/icon4pytools/liskov/parsing/parse.py | 8 +-- .../icon4pytools/liskov/parsing/validation.py | 37 ++++---------- .../liskov/pipeline/collection.py | 12 ++--- tools/tests/f2ser/test_f2ser_codegen.py | 8 +-- .../tests/f2ser/test_granule_deserialiser.py | 6 +-- tools/tests/f2ser/test_parsing.py | 4 +- tools/tests/icon4pygen/test_backend.py | 4 +- .../tests/icon4pygen/test_field_rendering.py | 8 +-- .../liskov/test_directives_deserialiser.py | 35 +++---------- tools/tests/liskov/test_external.py | 5 +- tools/tests/liskov/test_generation.py | 24 +++------ .../liskov/test_serialisation_deserialiser.py | 4 +- tools/tests/liskov/test_validation.py | 15 ++---- tools/tests/liskov/test_writer.py | 4 +- 35 files changed, 117 insertions(+), 350 deletions(-) diff --git a/tools/src/icon4pytools/common/logger.py b/tools/src/icon4pytools/common/logger.py index 636a582d49..d7f7341fed 100644 --- a/tools/src/icon4pytools/common/logger.py +++ b/tools/src/icon4pytools/common/logger.py @@ -18,9 +18,7 @@ def setup_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: """Set up a logger with a given name and log level.""" logger = logging.getLogger(name) logger.setLevel(log_level) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") stream_handler = logging.StreamHandler() stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) diff --git a/tools/src/icon4pytools/f2ser/cli.py b/tools/src/icon4pytools/f2ser/cli.py index a122a8d893..290d8a54f9 100644 --- a/tools/src/icon4pytools/f2ser/cli.py +++ b/tools/src/icon4pytools/f2ser/cli.py @@ -18,18 +18,14 @@ from icon4pytools.f2ser.deserialise import ParsedGranuleDeserialiser from icon4pytools.f2ser.parse import GranuleParser -from icon4pytools.liskov.codegen.serialisation.generate import ( - SerialisationCodeGenerator, -) +from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator from icon4pytools.liskov.codegen.shared.write import CodegenWriter @click.command("icon_f2ser") @click.argument( "granule_path", - type=click.Path( - exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path - 
), + type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), ) @click.argument( "output_filepath", diff --git a/tools/src/icon4pytools/f2ser/deserialise.py b/tools/src/icon4pytools/f2ser/deserialise.py index e0082eed66..883daddc94 100644 --- a/tools/src/icon4pytools/f2ser/deserialise.py +++ b/tools/src/icon4pytools/f2ser/deserialise.py @@ -55,9 +55,7 @@ def _make_savepoints(self) -> None: for intent, var_dict in intent_dict.items(): self._create_savepoint(subroutine_name, intent, var_dict) - def _create_savepoint( - self, subroutine_name: str, intent: str, var_dict: dict - ) -> None: + def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -> None: """Create a savepoint for the given variables. Args: @@ -73,9 +71,7 @@ def _create_savepoint( FieldSerialisationData( variable=var_name, association=self._create_association(var_data, var_name), - decomposed=var_data["decomposed"] - if var_data.get("decomposed") - else False, + decomposed=var_data["decomposed"] if var_data.get("decomposed") else False, dimension=var_data.get("dimension"), typespec=var_data.get("typespec"), typename=var_data.get("typename"), @@ -139,9 +135,7 @@ def _make_init_data(self) -> None: for intent, var_dict in intent_dict.items() if intent == "in" ][0] - startln = self._get_codegen_line( - first_intent_in_subroutine["codegen_ctx"], "init" - ) + startln = self._get_codegen_line(first_intent_in_subroutine["codegen_ctx"], "init") self.data["Init"] = InitData( startln=startln, directory=self.directory, diff --git a/tools/src/icon4pytools/f2ser/parse.py b/tools/src/icon4pytools/f2ser/parse.py index 3d8147da42..a68fe99003 100644 --- a/tools/src/icon4pytools/f2ser/parse.py +++ b/tools/src/icon4pytools/f2ser/parse.py @@ -59,9 +59,7 @@ class GranuleParser: parsed_types = parser() """ - def __init__( - self, granule: Path, dependencies: Optional[list[Path]] = None - ) -> None: + def __init__(self, granule: Path, dependencies: Optional[list[Path]] = None) -> None: self.granule_path = granule self.dependencies = dependencies @@ -72,7 +70,7 @@ def __call__(self) -> ParsedGranule: return ParsedGranule(subroutines=subroutines, last_import_ln=last_import_ln) def _find_last_fortran_use_statement(self) -> Optional[int]: - """Finds the line number of the last Fortran USE statement in the code. + """Find the line number of the last Fortran USE statement in the code. Returns: int: the line number of the last USE statement, or None if no USE statement is found. 
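The docstring edits above touch _find_last_fortran_use_statement; the scan itself is unchanged by this patch. A behaviour-preserving sketch of it, written against a plain list of source lines rather than the granule file (function and variable names here are illustrative only): take the last USE statement and, when that USE is immediately followed by an #endif, shift the insertion point past the preprocessor guard.

    from typing import Optional

    def find_last_use_statement(code_lines: list[str]) -> Optional[int]:
        """Return the 1-based line number after which generated imports can go."""
        last = None
        for idx, line in enumerate(code_lines):
            if line.strip().lower().startswith("use"):
                last = idx
        if last is None:
            return None
        # If the last USE sits inside a preprocessor block whose #endif follows
        # immediately, insertion has to happen after the #endif instead.
        if last + 1 < len(code_lines) and code_lines[last + 1].strip().lower() == "#endif":
            return last + 2
        return last + 1

    assert find_last_use_statement(["use mo_kind", "#endif", "implicit none"]) == 2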
@@ -95,7 +93,7 @@ def _find_last_fortran_use_statement(self) -> Optional[int]: return None def _read_code_from_file(self) -> str: - """Reads the content of the granule and returns it as a string.""" + """Read the content of the granule and return it as a string.""" with open(self.granule_path) as f: code = f.read() return code @@ -103,12 +101,9 @@ def _read_code_from_file(self) -> str: def parse_subroutines(self): subroutines = self._extract_subroutines(crack(self.granule_path)) variables_grouped_by_intent = { - name: self._extract_intent_vars(routine) - for name, routine in subroutines.items() + name: self._extract_intent_vars(routine) for name, routine in subroutines.items() } - intrinsic_type_vars, derived_type_vars = self._parse_types( - variables_grouped_by_intent - ) + intrinsic_type_vars, derived_type_vars = self._parse_types(variables_grouped_by_intent) combined_type_vars = self._combine_types(derived_type_vars, intrinsic_type_vars) with_lines = self._update_with_codegen_lines(combined_type_vars) return with_lines @@ -129,9 +124,7 @@ def _extract_subroutines(self, parsed: dict[str, Any]) -> dict[str, Any]: subroutines[name] = elt if len(subroutines) != 2: - raise ParsingError( - f"Did not find _init and _run subroutines in {self.granule_path}" - ) + raise ParsingError(f"Did not find _init and _run subroutines in {self.granule_path}") return subroutines @@ -245,9 +238,7 @@ def _decompose_derived_types(derived_types: dict) -> dict: new_type_name = f"{var_name}_{subtype_name}" new_var_dict = var_dict.copy() new_var_dict.update(subtype_spec) - decomposed_vars[subroutine][intent][ - new_type_name - ] = new_var_dict + decomposed_vars[subroutine][intent][new_type_name] = new_var_dict new_var_dict["ptr_var"] = subtype_name else: decomposed_vars[subroutine][intent][var_name] = var_dict @@ -284,9 +275,9 @@ def _update_with_codegen_lines(self, parsed_types: dict) -> dict: with_lines = deepcopy(parsed_types) for subroutine in with_lines: for intent in with_lines[subroutine]: - with_lines[subroutine][intent][ - "codegen_ctx" - ] = self._get_subroutine_lines(subroutine) + with_lines[subroutine][intent]["codegen_ctx"] = self._get_subroutine_lines( + subroutine ) return with_lines def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext: @@ -300,9 +291,7 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext: """ code = self._read_code_from_file() - start_subroutine_ln, end_subroutine_ln = self._find_subroutine_lines( - code, subroutine_name - ) + start_subroutine_ln, end_subroutine_ln = self._find_subroutine_lines(code, subroutine_name) variable_declaration_ln = self._find_variable_declarations( code, start_subroutine_ln, end_subroutine_ln @@ -314,21 +303,17 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext: ( first_declaration_ln, last_declaration_ln, - ) = self._get_variable_declaration_bounds( - variable_declaration_ln, start_subroutine_ln - ) + ) = self._get_variable_declaration_bounds(variable_declaration_ln, start_subroutine_ln) pre_end_subroutine_ln = ( end_subroutine_ln - 1 ) # we want to generate the code before the end of the subroutine - return CodegenContext( - first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln - ) + return CodegenContext(first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln) @staticmethod def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int]: - """Finds line numbers of a subroutine within a code block. + """Find line numbers of a subroutine within a code block.
Args: code (str): The code block to search for the subroutine. @@ -351,7 +336,7 @@ def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int]: def _find_variable_declarations( code: str, start_subroutine_ln: int, end_subroutine_ln: int ) -> list: - """Finds line numbers of variable declarations within a code block. + """Find line numbers of variable declarations within a code block. Args: code (str): The code block to search for variable declarations. @@ -370,9 +355,7 @@ def _find_variable_declarations( is_multiline_declaration = False declaration_pattern_lines = [] - for i, line in enumerate( - code.splitlines()[start_subroutine_ln:end_subroutine_ln] - ): + for i, line in enumerate(code.splitlines()[start_subroutine_ln:end_subroutine_ln]): if not is_multiline_declaration: if re.search(declaration_pattern, line): declaration_pattern_lines.append(i) @@ -390,7 +373,7 @@ def _find_variable_declarations( def _get_variable_declaration_bounds( declaration_pattern_lines: list, start_subroutine_ln: int ) -> tuple: - """Returns the line numbers of the bounds for a variable declaration block. + """Return the line numbers of the bounds for a variable declaration block. Args: declaration_pattern_lines (list): List of line numbers representing the relative positions of lines within the declaration block. diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py index 73c7c4a234..ba6951f572 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py @@ -27,9 +27,7 @@ run_func_declaration, run_verify_func_declaration, ) -from icon4pytools.icon4pygen.bindings.codegen.render.offset import ( - GpuTriMeshOffsetRenderer, -) +from icon4pytools.icon4pygen.bindings.codegen.render.offset import GpuTriMeshOffsetRenderer from icon4pytools.icon4pygen.bindings.entities import Field, Offset from icon4pytools.icon4pygen.bindings.utils import write_string @@ -665,12 +663,8 @@ def _get_field_data(self) -> tuple: ] sparse_fields = [field for field in self.fields if field.is_sparse()] compound_fields = [field for field in self.fields if field.is_compound()] - sparse_offsets = [ - offset for offset in self.offsets if not offset.is_compound_location() - ] - strided_offsets = [ - offset for offset in self.offsets if offset.is_compound_location() - ] + sparse_offsets = [offset for offset in self.offsets if not offset.is_compound_location()] + strided_offsets = [offset for offset in self.offsets if offset.is_compound_location()] all_fields = self.fields offsets = dict(sparse=sparse_offsets, strided=strided_offsets) @@ -715,9 +709,7 @@ def __post_init__(self) -> None: # type: ignore ), public_utilities=PublicUtilities(fields=fields["output"]), copy_pointers=CopyPointers(fields=self.fields), - private_members=PrivateMembers( - fields=self.fields, out_fields=fields["output"] - ), + private_members=PrivateMembers(fields=self.fields, out_fields=fields["output"]), setup_func=StencilClassSetupFunc( funcname=self.stencil_name, out_fields=fields["output"], diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py index a72aaedb67..a3cc0a3e41 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py @@ -164,9 +164,7 @@ class F90Generator(TemplatedGenerator): F90Field = as_jinja("{{ name }}{% if suffix %}_{{ suffix }}{% endif %}") 
- F90OpenACCField = as_jinja( - "!$ACC {{ name }}{% if suffix %}_{{ suffix }}{% endif %}" - ) + F90OpenACCField = as_jinja("!$ACC {{ name }}{% if suffix %}_{{ suffix }}{% endif %}") F90TypedField = as_jinja( "{{ dtype }}, {% if dims %}{{ dims }},{% endif %} target {% if _this_node.optional %} , optional {% endif %}:: {{ name }}{% if suffix %}_{{ suffix }}{% endif %} " @@ -228,13 +226,10 @@ def __post_init__(self) -> None: # type: ignore ) for field in self.all_fields ] + [ - F90TypedField(name=name, dtype="integer(c_int)", dims="value") - for name in _DOMAIN_ARGS + F90TypedField(name=name, dtype="integer(c_int)", dims="value") for name in _DOMAIN_ARGS ] - self.params = F90EntityList( - fields=param_fields, line_end=", &", line_end_last=" &" - ) + self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") self.binds = F90EntityList(fields=bind_fields) @@ -278,9 +273,7 @@ def __post_init__(self) -> None: # type: ignore ) for field in self.tol_fields: - param_fields += [ - F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"] - ] + param_fields += [F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"]] bind_fields += [ F90TypedField( name=field.name, @@ -291,9 +284,7 @@ def __post_init__(self) -> None: # type: ignore for s in ["rel_tol", "abs_tol"] ] - self.params = F90EntityList( - fields=param_fields, line_end=", &", line_end_last=" &" - ) + self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") self.binds = F90EntityList(fields=bind_fields) @@ -337,9 +328,7 @@ def __post_init__(self) -> None: # type: ignore for field in self.out_fields ] - self.params = F90EntityList( - fields=param_fields, line_end=", &", line_end_last=" &" - ) + self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") self.binds = F90EntityList(fields=bind_fields) @@ -401,9 +390,7 @@ def __post_init__(self) -> None: # type: ignore for field in self.tol_fields ] open_acc_fields = [ - F90OpenACCField(name=field.name) - for field in self.all_fields - if field.rank() != 0 + F90OpenACCField(name=field.name) for field in self.all_fields if field.rank() != 0 ] + [ F90OpenACCField(name=field.name, suffix="before") for field in self.out_fields @@ -433,9 +420,7 @@ def __post_init__(self) -> None: # type: ignore ] for field in self.tol_fields: - param_fields += [ - F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"] - ] + param_fields += [F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"]] bind_fields += [ F90TypedField( name=field.name, @@ -447,25 +432,18 @@ def __post_init__(self) -> None: # type: ignore for s in ["rel_tol", "abs_tol"] ] run_ver_param_fields += [ - F90Field(name=field.name, suffix=s) - for s in ["rel_err_tol", "abs_err_tol"] + F90Field(name=field.name, suffix=s) for s in ["rel_err_tol", "abs_err_tol"] ] - self.params = F90EntityList( - fields=param_fields, line_end=", &", line_end_last=" &" - ) + self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") self.binds = F90EntityList(fields=bind_fields) self.tol_decls = F90EntityList(fields=tol_fields) self.conditionals = F90EntityList(fields=cond_fields) - self.openacc = F90EntityList( - fields=open_acc_fields, line_end=", &", line_end_last=" &" - ) + self.openacc = F90EntityList(fields=open_acc_fields, line_end=", &", line_end_last=" &") self.run_ver_params = F90EntityList( fields=run_ver_param_fields, line_end=", &", line_end_last=" &" ) - self.run_params = F90EntityList( - 
fields=run_param_fields, line_end=", &", line_end_last=" &" - ) + self.run_params = F90EntityList(fields=run_param_fields, line_end=", &", line_end_last=" &") class F90WrapSetupFun(Node): @@ -536,9 +514,7 @@ def __post_init__(self) -> None: # type: ignore ] ] + [F90Field(name=field.name, suffix="kvert_max") for field in self.out_fields] - self.params = F90EntityList( - fields=param_fields, line_end=", &", line_end_last=" &" - ) + self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &") self.binds = F90EntityList(fields=bind_fields) self.vert_decls = F90EntityList(fields=vert_fields) self.vert_conditionals = F90EntityList(fields=vert_conditionals_fields) diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py index 2c170f96c7..e9106fba0e 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/field.py @@ -34,11 +34,7 @@ def render_pointer(self) -> str: def render_dim_tags(self) -> str: """Render c++ dimension tags.""" tags = [] - if ( - self.entity.is_dense() - or self.entity.is_sparse() - or self.entity.is_compound() - ): + if self.entity.is_dense() or self.entity.is_sparse() or self.entity.is_compound(): tags.append("unstructured::dim::horizontal") if self.entity.has_vertical_dimension: tags.append("unstructured::dim::vertical") @@ -50,9 +46,7 @@ def render_sid(self) -> str: raise BindingsRenderingException("can not render sid of a scalar") # We want to compute the rank without the sparse dimension, i.e. if a field is horizontal, vertical or both. - dense_rank = self.entity.rank() - int( - self.entity.is_sparse() or self.entity.is_compound() - ) + dense_rank = self.entity.rank() - int(self.entity.is_sparse() or self.entity.is_compound()) if dense_rank == 1: values_str = "1" elif self.entity.is_compound(): diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py index c0a1d17811..aa40dcf2ad 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/render/offset.py @@ -77,9 +77,7 @@ def __init__(self, offsets: Sequence[OffsetEntity]): def make_table_vars(self) -> list[str]: if not self.has_offsets: return [] - unique_offsets = sorted( - {self._make_table_var(offset) for offset in self.offsets} - ) + unique_offsets = sorted({self._make_table_var(offset) for offset in self.offsets}) return list(unique_offsets) def make_neighbor_tables(self) -> list[str]: @@ -103,7 +101,4 @@ def _make_table_var(offset: OffsetEntity) -> str: @staticmethod def _make_location_type(offset: OffsetEntity) -> list[str]: - return [ - f"LocationType::{loc.render_location_type()}" - for loc in offset.target[1].chain - ] + return [f"LocationType::{loc.render_location_type()}" for loc in offset.target[1].chain] diff --git a/tools/src/icon4pytools/icon4pygen/bindings/entities.py b/tools/src/icon4pytools/icon4pygen/bindings/entities.py index 44dc97c1f1..cad08c7c0d 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/entities.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/entities.py @@ -19,11 +19,7 @@ from icon4pytools.icon4pygen.bindings.codegen.render.field import FieldRenderer from icon4pytools.icon4pygen.bindings.codegen.render.offset import OffsetRenderer -from icon4pytools.icon4pygen.bindings.codegen.types import ( - FieldEntity, - 
FieldIntent, - OffsetEntity, -) +from icon4pytools.icon4pygen.bindings.codegen.types import FieldEntity, FieldIntent, OffsetEntity from icon4pytools.icon4pygen.bindings.exceptions import BindingsTypeConsistencyException from icon4pytools.icon4pygen.bindings.locations import ( BASIC_LOCATIONS, @@ -75,9 +71,7 @@ def _handle_source(chain: str) -> Union[BasicLocation, CompoundLocation]: if source in [str(loc()) for loc in BASIC_LOCATIONS.values()]: return chain_from_str(source)[0] - elif all( - char in [str(loc()) for loc in BASIC_LOCATIONS.values()] for char in source - ): + elif all(char in [str(loc()) for loc in BASIC_LOCATIONS.values()] for char in source): return CompoundLocation(chain_from_str(source)) else: raise BindingsTypeConsistencyException(f"Invalid source {source}") @@ -162,9 +156,7 @@ def _update_horizontal_location(self, field: past.DataSymbol) -> None: if not isinstance(field.type, ts.FieldType): return - maybe_horizontal_dimension = list( - filter(lambda dim: dim.value != "K", field.type.dims) - ) + maybe_horizontal_dimension = list(filter(lambda dim: dim.value != "K", field.type.dims)) # early abort if field is vertical if not len(maybe_horizontal_dimension): diff --git a/tools/src/icon4pytools/icon4pygen/bindings/locations.py b/tools/src/icon4pytools/icon4pygen/bindings/locations.py index 35a825ddf7..4071da48b6 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/locations.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/locations.py @@ -64,9 +64,7 @@ def __init__(self, chain: list[BasicLocation]) -> None: if is_valid(chain): self.chain = chain else: - raise Exception( - f"chain {chain} contains two of the same elements in succession" - ) + raise Exception(f"chain {chain} contains two of the same elements in succession") def __iter__(self) -> Iterator[BasicLocation]: return iter(self.chain) diff --git a/tools/src/icon4pytools/icon4pygen/bindings/workflow.py b/tools/src/icon4pytools/icon4pygen/bindings/workflow.py index 108f602fd6..83a5d1f6ee 100644 --- a/tools/src/icon4pytools/icon4pygen/bindings/workflow.py +++ b/tools/src/icon4pytools/icon4pygen/bindings/workflow.py @@ -34,9 +34,7 @@ class PyBindGen: Furthermore, we also serialise data to .csv or .vtk files in case of verification failure. """ - def __init__( - self, stencil_info: StencilInfo, levels_per_thread: int, block_size: int - ) -> None: + def __init__(self, stencil_info: StencilInfo, levels_per_thread: int, block_size: int) -> None: self.stencil_name = stencil_info.itir.id self.fields, self.offsets = self._stencil_info_to_binding_type(stencil_info) self.levels_per_thread = levels_per_thread diff --git a/tools/src/icon4pytools/icon4pygen/cli.py b/tools/src/icon4pytools/icon4pygen/cli.py index 0a1f0f0233..66180adf44 100644 --- a/tools/src/icon4pytools/icon4pygen/cli.py +++ b/tools/src/icon4pytools/icon4pygen/cli.py @@ -45,9 +45,7 @@ def shell_complete(self, ctx, param, incomplete): @click.argument("fencil", type=ModuleType()) @click.argument("block_size", type=int, default=128) @click.argument("levels_per_thread", type=int, default=4) -@click.option( - "--is_global", is_flag=True, type=bool, help="Whether this is a global run." 
-) +@click.option("--is_global", is_flag=True, type=bool, help="Whether this is a global run.") @click.argument( "outpath", type=click.Path(dir_okay=True, resolve_path=True, path_type=pathlib.Path), diff --git a/tools/src/icon4pytools/icon4pygen/metadata.py b/tools/src/icon4pytools/icon4pygen/metadata.py index 0b4a86e02f..1f1b42e4cb 100644 --- a/tools/src/icon4pytools/icon4pygen/metadata.py +++ b/tools/src/icon4pytools/icon4pygen/metadata.py @@ -142,9 +142,7 @@ def get_fvprog(fencil_def: Program | Any) -> Program: return fvprog -def provide_offset( - offset: str, is_global: bool = False -) -> DummyConnectivity | Dimension: +def provide_offset(offset: str, is_global: bool = False) -> DummyConnectivity | Dimension: if offset == Koff.value: assert len(Koff.target) == 1 assert Koff.source == Koff.target[0] @@ -166,9 +164,7 @@ def provide_neighbor_table(chain: str, is_global: bool) -> DummyConnectivity: and pass the tokens after to the algorithm below """ # note: this seems really brittle. maybe agree on a keyword to indicate new sparse fields? - new_sparse_field = any( - len(token) > 1 for token in chain.split("2") - ) and not chain.endswith("O") + new_sparse_field = any(len(token) > 1 for token in chain.split("2")) and not chain.endswith("O") if new_sparse_field: chain = chain.split("2")[1] skip_values = False @@ -199,13 +195,9 @@ def provide_neighbor_table(chain: str, is_global: bool) -> DummyConnectivity: def scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]: """Scan PAST node for offsets and return a set of all offsets.""" - all_types = ( - fvprog.past_node.pre_walk_values().if_isinstance(past.Symbol).getattr("type") - ) + all_types = fvprog.past_node.pre_walk_values().if_isinstance(past.Symbol).getattr("type") all_field_types = [ - symbol_type - for symbol_type in all_types - if isinstance(symbol_type, ts.FieldType) + symbol_type for symbol_type in all_types if isinstance(symbol_type, ts.FieldType) ] all_dims = set(i for j in all_field_types for i in j.dims) all_offset_labels = ( diff --git a/tools/src/icon4pytools/liskov/cli.py b/tools/src/icon4pytools/liskov/cli.py index 806bcdcec8..a16e0a428e 100644 --- a/tools/src/icon4pytools/liskov/cli.py +++ b/tools/src/icon4pytools/liskov/cli.py @@ -52,9 +52,7 @@ def main(ctx): ) @click.argument( "input_path", - type=click.Path( - exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path - ), + type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), ) @click.argument( "output_path", @@ -84,9 +82,7 @@ def integrate(input_path, output_path, profile, metadatagen): ) @click.argument( "input_path", - type=click.Path( - exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path - ), + type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), ) @click.argument( "output_path", diff --git a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py index 2f4d6c23da..4b4605c7cb 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py @@ -115,9 +115,7 @@ def __call__(self, parsed: ts.ParsedDict) -> CodeGenInput: class EndCreateDataFactory(OptionalMultiUseDataFactory): - directive_cls: Type[ - ts.ParsedDirective - ] = icon4pytools.liskov.parsing.parse.EndCreate + directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndCreate dtype: Type[EndCreateData] = EndCreateData @@ 
@@ -132,16 +130,12 @@ class EndIfDataFactory(OptionalMultiUseDataFactory):
 
 
 class EndProfileDataFactory(OptionalMultiUseDataFactory):
-    directive_cls: Type[
-        ts.ParsedDirective
-    ] = icon4pytools.liskov.parsing.parse.EndProfile
+    directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndProfile
     dtype: Type[EndProfileData] = EndProfileData
 
 
 class StartCreateDataFactory(DataFactoryBase):
-    directive_cls: Type[
-        ts.ParsedDirective
-    ] = icon4pytools.liskov.parsing.parse.StartCreate
+    directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartCreate
     dtype: Type[StartCreateData] = StartCreateData
 
     def __call__(self, parsed: ts.ParsedDict) -> list[StartCreateData]:
@@ -181,9 +175,7 @@ def __call__(self, parsed: ts.ParsedDict) -> list[DeclareData]:
         extracted = extract_directive(parsed["directives"], self.directive_cls)
         for i, directive in enumerate(extracted):
             named_args = parsed["content"]["Declare"][i]
-            ident_type = pop_item_from_dict(
-                named_args, "type", DEFAULT_DECLARE_IDENT_TYPE
-            )
+            ident_type = pop_item_from_dict(named_args, "type", DEFAULT_DECLARE_IDENT_TYPE)
             suffix = pop_item_from_dict(named_args, "suffix", DEFAULT_DECLARE_SUFFIX)
             deserialised.append(
                 self.dtype(
@@ -197,9 +189,7 @@ def __call__(self, parsed: ts.ParsedDict) -> list[DeclareData]:
 
 
 class StartProfileDataFactory(DataFactoryBase):
-    directive_cls: Type[
-        ts.ParsedDirective
-    ] = icon4pytools.liskov.parsing.parse.StartProfile
+    directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartProfile
     dtype: Type[StartProfileData] = StartProfileData
 
     def __call__(self, parsed: ts.ParsedDict) -> list[StartProfileData]:
@@ -208,16 +198,12 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartProfileData]:
         for i, directive in enumerate(extracted):
             named_args = parsed["content"]["StartProfile"][i]
             stencil_name = _extract_stencil_name(named_args, directive)
-            deserialised.append(
-                self.dtype(name=stencil_name, startln=directive.startln)
-            )
+            deserialised.append(self.dtype(name=stencil_name, startln=directive.startln))
         return deserialised
 
 
 class EndStencilDataFactory(DataFactoryBase):
-    directive_cls: Type[
-        ts.ParsedDirective
-    ] = icon4pytools.liskov.parsing.parse.EndStencil
+    directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndStencil
     dtype: Type[EndStencilData] = EndStencilData
 
     def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]:
@@ -242,9 +228,7 @@ def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]:
 
 
 class StartStencilDataFactory(DataFactoryBase):
-    directive_cls: Type[
-        ts.ParsedDirective
-    ] = icon4pytools.liskov.parsing.parse.StartStencil
+    directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil
    dtype: Type[StartStencilData] = StartStencilData
 
     def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
@@ -258,20 +242,13 @@ def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
         """
         deserialised = []
         field_dimensions = flatten_list_of_dicts(
-            [
-                DeclareDataFactory.get_field_dimensions(dim)
-                for dim in parsed["content"]["Declare"]
-            ]
+            [DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]]
         )
         directives = extract_directive(parsed["directives"], self.directive_cls)
         for i, directive in enumerate(directives):
             named_args = parsed["content"]["StartStencil"][i]
-            acc_present = string_to_bool(
-                pop_item_from_dict(named_args, "accpresent", "true")
-            )
-            mergecopy = string_to_bool(
-                pop_item_from_dict(named_args, "mergecopy", "false")
-            )
+            acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
+            mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
             copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
             stencil_name = _extract_stencil_name(named_args, directive)
             bounds = self._make_bounds(named_args)
diff --git a/tools/src/icon4pytools/liskov/codegen/integration/template.py b/tools/src/icon4pytools/liskov/codegen/integration/template.py
index 9b313d21d3..dba9543864 100644
--- a/tools/src/icon4pytools/liskov/codegen/integration/template.py
+++ b/tools/src/icon4pytools/liskov/codegen/integration/template.py
@@ -20,10 +20,7 @@
 from gt4py.eve.codegen import TemplatedGenerator
 
 from icon4pytools.liskov.codegen.integration.exceptions import UndeclaredFieldError
-from icon4pytools.liskov.codegen.integration.interface import (
-    DeclareData,
-    StartStencilData,
-)
+from icon4pytools.liskov.codegen.integration.interface import DeclareData, StartStencilData
 from icon4pytools.liskov.external.metadata import CodeMetadata
 
 
@@ -241,9 +238,7 @@ class StartStencilStatement(eve.Node):
 
     def __post_init__(self) -> None:  # type: ignore
         all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields]
-        self.copy_declarations = [
-            self.make_copy_declaration(f) for f in all_fields if f.out
-        ]
+        self.copy_declarations = [self.make_copy_declaration(f) for f in all_fields if f.out]
         self.acc_present = "PRESENT" if self.stencil_data.acc_present else "NONE"
 
     @staticmethod
diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py
index d64adfc707..9db5533201 100644
--- a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py
+++ b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py
@@ -178,9 +178,7 @@ def _get_timestep_variables(stencil_name: str) -> dict:
 
     if "mo_icon_interpolation_scalar" in stencil_name:
         timestep_variables["jstep"] = "jstep_ptr"
-        timestep_variables[
-            "mo_icon_interpolation_ctr"
-        ] = "mo_icon_interpolation_ctr"
+        timestep_variables["mo_icon_interpolation_ctr"] = "mo_icon_interpolation_ctr"
 
     if "mo_advection_traj" in stencil_name:
         timestep_variables["jstep"] = "jstep_ptr"
diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py b/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py
index 0b2c343596..fed5b8f735 100644
--- a/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py
+++ b/tools/src/icon4pytools/liskov/codegen/serialisation/generate.py
@@ -13,9 +13,7 @@
 from typing import Any
 
 from icon4pytools.common.logger import setup_logger
-from icon4pytools.liskov.codegen.serialisation.interface import (
-    SerialisationCodeInterface,
-)
+from icon4pytools.liskov.codegen.serialisation.interface import SerialisationCodeInterface
 from icon4pytools.liskov.codegen.serialisation.template import (
     ImportStatement,
     ImportStatementGenerator,
diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/template.py b/tools/src/icon4pytools/liskov/codegen/serialisation/template.py
index fe9d630ebd..8c9f8d5325 100644
--- a/tools/src/icon4pytools/liskov/codegen/serialisation/template.py
+++ b/tools/src/icon4pytools/liskov/codegen/serialisation/template.py
@@ -49,15 +49,11 @@ class SavepointStatement(eve.Node):
     multinode: bool
     standard_fields: StandardFields = eve.datamodels.field(init=False)
     decomposed_fields: DecomposedFields = eve.datamodels.field(init=False)
-    decomposed_field_declarations: DecomposedFieldDeclarations = eve.datamodels.field(
-        init=False
-    )
+    decomposed_field_declarations: DecomposedFieldDeclarations = eve.datamodels.field(init=False)
 
     def __post_init__(self):
         self.standard_fields = StandardFields(
-            fields=[
-                Field(**asdict(f)) for f in self.savepoint.fields if not f.decomposed
-            ]
+            fields=[Field(**asdict(f)) for f in self.savepoint.fields if not f.decomposed]
         )
         self.decomposed_fields = DecomposedFields(
             fields=[Field(**asdict(f)) for f in self.savepoint.fields if f.decomposed]
diff --git a/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py b/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py
index 302b0d762f..e71ef9fe3f 100644
--- a/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py
+++ b/tools/src/icon4pytools/liskov/codegen/shared/deserialise.py
@@ -16,9 +16,7 @@
 import icon4pytools.liskov.parsing.types as ts
 from icon4pytools.common.logger import setup_logger
 from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface
-from icon4pytools.liskov.codegen.serialisation.interface import (
-    SerialisationCodeInterface,
-)
+from icon4pytools.liskov.codegen.serialisation.interface import SerialisationCodeInterface
 from icon4pytools.liskov.pipeline.definition import Step
 
 
diff --git a/tools/src/icon4pytools/liskov/external/gt4py.py b/tools/src/icon4pytools/liskov/external/gt4py.py
index 126abe2175..4ad14ef9d3 100644
--- a/tools/src/icon4pytools/liskov/external/gt4py.py
+++ b/tools/src/icon4pytools/liskov/external/gt4py.py
@@ -20,10 +20,7 @@
 from icon4pytools.common.logger import setup_logger
 from icon4pytools.icon4pygen.metadata import get_stencil_info
 from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface
-from icon4pytools.liskov.external.exceptions import (
-    IncompatibleFieldError,
-    UnknownStencilError,
-)
+from icon4pytools.liskov.external.exceptions import IncompatibleFieldError, UnknownStencilError
 from icon4pytools.liskov.pipeline.definition import Step
 
 
@@ -65,9 +62,7 @@ def _collect_icon4py_stencil(self, stencil_name: str) -> Program:
                 err_counter += 1
 
         if err_counter == len(self._STENCIL_PACKAGES):
-            raise UnknownStencilError(
-                f"Did not find module: {stencil_name} in icon4pytools."
-            )
+            raise UnknownStencilError(f"Did not find module: {stencil_name} in icon4pytools.")
 
         module_members = getmembers(module)
         found_stencil = [elt for elt in module_members if elt[0] == stencil_name]
diff --git a/tools/src/icon4pytools/liskov/parsing/parse.py b/tools/src/icon4pytools/liskov/parsing/parse.py
index 984dae6ca0..02759621a5 100644
--- a/tools/src/icon4pytools/liskov/parsing/parse.py
+++ b/tools/src/icon4pytools/liskov/parsing/parse.py
@@ -81,18 +81,14 @@ def _determine_type(
             )
         return typed
 
-    def _preprocess(
-        self, directives: Sequence[ts.ParsedDirective]
-    ) -> Sequence[ts.ParsedDirective]:
+    def _preprocess(self, directives: Sequence[ts.ParsedDirective]) -> Sequence[ts.ParsedDirective]:
         """Preprocess the directives by removing unnecessary characters and formatting the directive strings."""
         return [
             d.__class__(self._clean_string(d.string), d.startln, d.endln)  # type: ignore
             for d in directives
         ]
 
-    def _run_validation_passes(
-        self, preprocessed: Sequence[ts.ParsedDirective]
-    ) -> None:
+    def _run_validation_passes(self, preprocessed: Sequence[ts.ParsedDirective]) -> None:
         """Run validation passes on the directives."""
         for validator in VALIDATORS:
             validator(self.input_filepath).validate(preprocessed)
diff --git a/tools/src/icon4pytools/liskov/parsing/validation.py b/tools/src/icon4pytools/liskov/parsing/validation.py
index c27361b8b8..f05a97e720 100644
--- a/tools/src/icon4pytools/liskov/parsing/validation.py
+++ b/tools/src/icon4pytools/liskov/parsing/validation.py
@@ -25,10 +25,7 @@
     RequiredDirectivesError,
     UnbalancedStencilDirectiveError,
 )
-from icon4pytools.liskov.parsing.utils import (
-    print_parsed_directive,
-    remove_directive_types,
-)
+from icon4pytools.liskov.parsing.utils import print_parsed_directive, remove_directive_types
 
 
 logger = setup_logger(__name__)
@@ -68,16 +65,12 @@ def validate(self, directives: list[ts.ParsedDirective]) -> None:
             self._validate_outer(d.string, d.pattern, d)
             self._validate_inner(d.string, d.pattern, d)
 
-    def _validate_outer(
-        self, to_validate: str, pattern: str, d: ts.ParsedDirective
-    ) -> None:
+    def _validate_outer(self, to_validate: str, pattern: str, d: ts.ParsedDirective) -> None:
         regex = f"{pattern}\\((.*)\\)"
         match = re.fullmatch(regex, to_validate)
         self.exception_handler.check_for_matches(d, match, regex, self.filepath)
 
-    def _validate_inner(
-        self, to_validate: str, pattern: str, d: ts.ParsedDirective
-    ) -> None:
+    def _validate_inner(self, to_validate: str, pattern: str, d: ts.ParsedDirective) -> None:
         inner = to_validate.replace(f"{pattern}", "")[1:-1].split(";")
         for arg in inner:
             match = re.fullmatch(d.regex, arg)
@@ -108,9 +101,7 @@ def validate(self, directives: list[ts.ParsedDirective]) -> None:
         self._validate_required_directives(directives)
         self._validate_stencil_directives(directives)
 
-    def _validate_directive_uniqueness(
-        self, directives: list[ts.ParsedDirective]
-    ) -> None:
+    def _validate_directive_uniqueness(self, directives: list[ts.ParsedDirective]) -> None:
         """Check that all used directives are unique.
 
         Note: Allow repeated START STENCIL, END STENCIL and ENDIF directives.
@@ -134,9 +125,7 @@ def _validate_directive_uniqueness(
                 f"Error in {self.filepath}.\n Found same directive more than once in the following directives:\n {pretty_printed}"
             )
 
-    def _validate_required_directives(
-        self, directives: list[ts.ParsedDirective]
-    ) -> None:
+    def _validate_required_directives(self, directives: list[ts.ParsedDirective]) -> None:
         """Check that all required directives are used at least once."""
         expected = [
             icon4pytools.liskov.parsing.parse.Declare,
@@ -156,13 +145,9 @@ def extract_arg_from_directive(directive: str, arg: str) -> str:
         if match:
             return match.group(1)
         else:
-            raise ValueError(
-                f"Invalid directive string, could not find '{arg}' parameter."
-            )
+            raise ValueError(f"Invalid directive string, could not find '{arg}' parameter.")
 
-    def _validate_stencil_directives(
-        self, directives: list[ts.ParsedDirective]
-    ) -> None:
+    def _validate_stencil_directives(self, directives: list[ts.ParsedDirective]) -> None:
         """Validate that the number of start and end stencil directives match in the input `directives`.
 
         Also verifies that each unique stencil has a corresponding start and end directive.
@@ -186,14 +171,10 @@ def _validate_stencil_directives(
         for directive in stencil_directives:
             stencil_name = self.extract_arg_from_directive(directive.string, "name")
             stencil_counts[stencil_name] = stencil_counts.get(stencil_name, 0) + (
-                1
-                if isinstance(directive, icon4pytools.liskov.parsing.parse.StartStencil)
-                else -1
+                1 if isinstance(directive, icon4pytools.liskov.parsing.parse.StartStencil) else -1
             )
-        unbalanced_stencils = [
-            stencil for stencil, count in stencil_counts.items() if count != 0
-        ]
+        unbalanced_stencils = [stencil for stencil, count in stencil_counts.items() if count != 0]
         if unbalanced_stencils:
             raise UnbalancedStencilDirectiveError(
                 f"Error in {self.filepath}. Each unique stencil must have a corresponding START STENCIL and END STENCIL directive."
diff --git a/tools/src/icon4pytools/liskov/pipeline/collection.py b/tools/src/icon4pytools/liskov/pipeline/collection.py
index b71c2d2d4f..006fa8a0e9 100644
--- a/tools/src/icon4pytools/liskov/pipeline/collection.py
+++ b/tools/src/icon4pytools/liskov/pipeline/collection.py
@@ -12,17 +12,11 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 from pathlib import Path
 
-from icon4pytools.liskov.codegen.integration.deserialise import (
-    IntegrationCodeDeserialiser,
-)
+from icon4pytools.liskov.codegen.integration.deserialise import IntegrationCodeDeserialiser
 from icon4pytools.liskov.codegen.integration.generate import IntegrationCodeGenerator
 from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface
-from icon4pytools.liskov.codegen.serialisation.deserialise import (
-    SerialisationCodeDeserialiser,
-)
-from icon4pytools.liskov.codegen.serialisation.generate import (
-    SerialisationCodeGenerator,
-)
+from icon4pytools.liskov.codegen.serialisation.deserialise import SerialisationCodeDeserialiser
+from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator
 from icon4pytools.liskov.codegen.shared.write import CodegenWriter
 from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils
 from icon4pytools.liskov.parsing.parse import DirectivesParser
diff --git a/tools/tests/f2ser/test_f2ser_codegen.py b/tools/tests/f2ser/test_f2ser_codegen.py
index 8beedb6432..5ca607eff9 100644
--- a/tools/tests/f2ser/test_f2ser_codegen.py
+++ b/tools/tests/f2ser/test_f2ser_codegen.py
@@ -15,9 +15,7 @@
 
 from icon4pytools.f2ser.deserialise import ParsedGranuleDeserialiser
 from icon4pytools.f2ser.parse import GranuleParser
-from icon4pytools.liskov.codegen.serialisation.generate import (
-    SerialisationCodeGenerator,
-)
+from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator
 from icon4pytools.liskov.codegen.shared.types import GeneratedCode
 
 
@@ -104,7 +102,5 @@ def test_deserialiser_directives_diffusion_codegen(
     parsed = GranuleParser(diffusion_granule, diffusion_granule_deps)()
     interface = ParsedGranuleDeserialiser(parsed)()
     generated = SerialisationCodeGenerator(interface)()
-    reference_savepoint = (
-        samples_path / "expected_diffusion_granule_savepoint.f90"
-    ).read_text()
+    reference_savepoint = (samples_path / "expected_diffusion_granule_savepoint.f90").read_text()
     assert generated[0].source == reference_savepoint.rstrip()
diff --git a/tools/tests/f2ser/test_granule_deserialiser.py b/tools/tests/f2ser/test_granule_deserialiser.py
index 0fa352b8fa..d0f478e7ce 100644
--- a/tools/tests/f2ser/test_granule_deserialiser.py
+++ b/tools/tests/f2ser/test_granule_deserialiser.py
@@ -83,11 +83,7 @@ def test_deserialiser_mock(mock_parsed_granule):
     assert len(interface.Savepoint) == 3
     assert all([isinstance(s, SavepointData) for s in interface.Savepoint])
     assert all(
-        [
-            isinstance(f, FieldSerialisationData)
-            for s in interface.Savepoint
-            for f in s.fields
-        ]
+        [isinstance(f, FieldSerialisationData) for s in interface.Savepoint for f in s.fields]
     )
 
 
diff --git a/tools/tests/f2ser/test_parsing.py b/tools/tests/f2ser/test_parsing.py
index e2a598c07a..b0346e72f1 100644
--- a/tools/tests/f2ser/test_parsing.py
+++ b/tools/tests/f2ser/test_parsing.py
@@ -51,9 +51,7 @@ def test_granule_parsing(diffusion_granule, diffusion_granule_deps):
 def test_granule_parsing_missing_derived_typedef(diffusion_granule, samples_path):
     dependencies = [samples_path / "subroutine_example.f90"]
     parser = GranuleParser(diffusion_granule, dependencies)
-    with pytest.raises(
-        MissingDerivedTypeError, match="Could not find type definition for TYPE"
-    ):
+    with pytest.raises(MissingDerivedTypeError, match="Could not find type definition for TYPE"):
         parser()
 
 
diff --git a/tools/tests/icon4pygen/test_backend.py b/tools/tests/icon4pygen/test_backend.py
index b0fbfd9b00..e712d3ffa6 100644
--- a/tools/tests/icon4pygen/test_backend.py
+++ b/tools/tests/icon4pygen/test_backend.py
@@ -28,8 +28,6 @@
 )
 def test_missing_domain_args(input_params, expected_complement):
     params = [itir.Sym(id=p) for p in input_params]
-    domain_boundaries = set(
-        map(lambda s: str(s.id), GTHeader._missing_domain_params(params))
-    )
+    domain_boundaries = set(map(lambda s: str(s.id), GTHeader._missing_domain_params(params)))
     assert len(domain_boundaries) == len(expected_complement)
     assert domain_boundaries == set(expected_complement)
diff --git a/tools/tests/icon4pygen/test_field_rendering.py b/tools/tests/icon4pygen/test_field_rendering.py
index 68231f57f4..2ace9a6e24 100644
--- a/tools/tests/icon4pygen/test_field_rendering.py
+++ b/tools/tests/icon4pygen/test_field_rendering.py
@@ -61,9 +61,7 @@ def identity(field: Field[[EdgeDim, KDim], float]) -> Field[[EdgeDim, KDim], flo
         return field
 
     @program
-    def identity_prog(
-        field: Field[[EdgeDim, KDim], float], out: Field[[EdgeDim, KDim], float]
-    ):
+    def identity_prog(field: Field[[EdgeDim, KDim], float], out: Field[[EdgeDim, KDim], float]):
         identity(field, out=out)
 
     stencil_info = get_stencil_info(identity_prog)
@@ -77,9 +75,7 @@ def identity_prog(
 
 def test_vertical_sparse_field_sid_rendering():
     @field_operator
-    def reduction(
-        nb_field: Field[[EdgeDim, E2CDim, KDim], float]
-    ) -> Field[[EdgeDim, KDim], float]:
+    def reduction(nb_field: Field[[EdgeDim, E2CDim, KDim], float]) -> Field[[EdgeDim, KDim], float]:
         return neighbor_sum(nb_field, axis=E2CDim)
 
     @program
diff --git a/tools/tests/liskov/test_directives_deserialiser.py b/tools/tests/liskov/test_directives_deserialiser.py
index b1de1b72ba..13a235f6f5 100644
--- a/tools/tests/liskov/test_directives_deserialiser.py
+++ b/tools/tests/liskov/test_directives_deserialiser.py
@@ -85,9 +85,7 @@
         ),
     ],
 )
-def test_data_factories_no_args(
-    factory_class, directive_type, string, startln, endln, expected
-):
+def test_data_factories_no_args(factory_class, directive_type, string, startln, endln, expected):
     parsed = {
         "directives": [directive_type(string=string, startln=startln, endln=endln)],
         "content": {},
@@ -111,9 +109,7 @@ def test_data_factories_no_args(
             {
                 "directives": [
                     ts.EndStencil("END STENCIL(name=foo)", 5, 5),
-                    ts.EndStencil(
-                        "END STENCIL(name=bar; noendif=true; noprofile=true)", 20, 20
-                    ),
+                    ts.EndStencil("END STENCIL(name=bar; noendif=true; noprofile=true)", 20, 20),
                 ],
                 "content": {
                     "EndStencil": [
@@ -127,9 +123,7 @@
             EndStencilDataFactory,
             EndStencilData,
             {
-                "directives": [
-                    ts.EndStencil("END STENCIL(name=foo; noprofile=true)", 5, 5)
-                ],
+                "directives": [ts.EndStencil("END STENCIL(name=foo; noprofile=true)", 5, 5)],
                 "content": {"EndStencil": [{"name": "foo"}]},
             },
         ),
@@ -203,9 +197,7 @@ def test_data_factories_with_args(factory, target, mock_data):
         ),
         (
             {
-                "directives": [
-                    ts.StartCreate("START CREATE(extra_fields=foo,xyz)", 5, 5)
-                ],
+                "directives": [ts.StartCreate("START CREATE(extra_fields=foo,xyz)", 5, 5)],
                 "content": {"StartCreate": [{"extra_fields": "foo,xyz"}]},
             },
             ["foo", "xyz"],
@@ -239,9 +231,7 @@ def test_start_create_factory(mock_data, extra_fields):
                     ts.EndStencil("END STENCIL(name=foo)", 5, 5),
ts.EndStencil("END STENCIL(name=bar; noendif=foo)", 20, 20), ], - "content": { - "EndStencil": [{"name": "foo"}, {"name": "bar", "noendif": "foo"}] - }, + "content": {"EndStencil": [{"name": "foo"}, {"name": "bar", "noendif": "foo"}]}, }, ), ], @@ -324,10 +314,7 @@ def test_update_field_tolerances(self): FieldAssociationData("x", "i", 3, rel_tol="0.01", abs_tol="0.1"), FieldAssociationData("y", "i", 3, rel_tol="0.001"), ] - assert ( - self.factory._update_tolerances(named_args, self.mock_fields) - == expected_fields - ) + assert self.factory._update_tolerances(named_args, self.mock_fields) == expected_fields def test_update_field_tolerances_not_all_fields(self): # Test that tolerance is not set for fields that are not provided in the named_args. @@ -339,15 +326,9 @@ def test_update_field_tolerances_not_all_fields(self): FieldAssociationData("x", "i", 3, rel_tol="0.01", abs_tol="0.1"), FieldAssociationData("y", "i", 3), ] - assert ( - self.factory._update_tolerances(named_args, self.mock_fields) - == expected_fields - ) + assert self.factory._update_tolerances(named_args, self.mock_fields) == expected_fields def test_update_field_tolerances_no_tolerances(self): # Test that fields are not updated if named_args does not contain any tolerances. named_args = {} - assert ( - self.factory._update_tolerances(named_args, self.mock_fields) - == self.mock_fields - ) + assert self.factory._update_tolerances(named_args, self.mock_fields) == self.mock_fields diff --git a/tools/tests/liskov/test_external.py b/tools/tests/liskov/test_external.py index e912a09932..1e0de1b5ca 100644 --- a/tools/tests/liskov/test_external.py +++ b/tools/tests/liskov/test_external.py @@ -22,10 +22,7 @@ IntegrationCodeInterface, StartStencilData, ) -from icon4pytools.liskov.external.exceptions import ( - IncompatibleFieldError, - UnknownStencilError, -) +from icon4pytools.liskov.external.exceptions import IncompatibleFieldError, UnknownStencilError from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils diff --git a/tools/tests/liskov/test_generation.py b/tools/tests/liskov/test_generation.py index 03bd4d5d49..dc16b09c6a 100644 --- a/tools/tests/liskov/test_generation.py +++ b/tools/tests/liskov/test_generation.py @@ -31,9 +31,7 @@ ) # TODO: fix tests to adapt to new custom output fields -from icon4pytools.liskov.codegen.serialisation.generate import ( - SerialisationCodeGenerator, -) +from icon4pytools.liskov.codegen.serialisation.generate import SerialisationCodeGenerator from icon4pytools.liskov.codegen.serialisation.interface import ( FieldSerialisationData, ImportData, @@ -51,9 +49,7 @@ def integration_code_interface(): fields=[ FieldAssociationData("scalar1", "scalar1", inp=True, out=False, dims=None), FieldAssociationData("inp1", "inp1(:,:,1)", inp=True, out=False, dims=2), - FieldAssociationData( - "out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5" - ), + FieldAssociationData("out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5"), FieldAssociationData( "out2", "p_nh%prog(nnew)%out2(:,:,1)", @@ -62,12 +58,8 @@ def integration_code_interface(): dims=3, abs_tol="0.2", ), - FieldAssociationData( - "out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2 - ), - FieldAssociationData( - "out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3 - ), + FieldAssociationData("out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2), + FieldAssociationData("out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3), FieldAssociationData( "out5", 
"p_nh%prog(nnew)%w(:,:,:,ntnd)", inp=False, out=True, dims=3 ), @@ -214,9 +206,7 @@ def expected_insert_source(): @pytest.fixture def integration_code_generator(integration_code_interface): - return IntegrationCodeGenerator( - integration_code_interface, profile=True, metadatagen=False - ) + return IntegrationCodeGenerator(integration_code_interface, profile=True, metadatagen=False) def test_integration_code_generation( @@ -337,9 +327,7 @@ def expected_savepoints(): def test_serialisation_code_generation( serialisation_code_interface, expected_savepoints, multinode ): - generated = SerialisationCodeGenerator( - serialisation_code_interface, multinode=multinode - )() + generated = SerialisationCodeGenerator(serialisation_code_interface, multinode=multinode)() if multinode: assert len(generated) == 3 diff --git a/tools/tests/liskov/test_serialisation_deserialiser.py b/tools/tests/liskov/test_serialisation_deserialiser.py index 380b8c14b9..0431086beb 100644 --- a/tools/tests/liskov/test_serialisation_deserialiser.py +++ b/tools/tests/liskov/test_serialisation_deserialiser.py @@ -109,9 +109,7 @@ def parsed_dict(): ], "StartProfile": [{"name": "apply_nabla2_to_vn_in_lateral_boundary"}], "EndProfile": [{}], - "EndStencil": [ - {"name": "apply_nabla2_to_vn_in_lateral_boundary", "noprofile": "True"} - ], + "EndStencil": [{"name": "apply_nabla2_to_vn_in_lateral_boundary", "noprofile": "True"}], "EndCreate": [{}], }, } diff --git a/tools/tests/liskov/test_validation.py b/tools/tests/liskov/test_validation.py index e02c08c46c..d6fbc4f927 100644 --- a/tools/tests/liskov/test_validation.py +++ b/tools/tests/liskov/test_validation.py @@ -20,12 +20,7 @@ RequiredDirectivesError, UnbalancedStencilDirectiveError, ) -from icon4pytools.liskov.parsing.parse import ( - Declare, - DirectivesParser, - Imports, - StartStencil, -) +from icon4pytools.liskov.parsing.parse import Declare, DirectivesParser, Imports, StartStencil from icon4pytools.liskov.parsing.validation import DirectiveSyntaxValidator from .conftest import insert_new_lines, scan_for_directives @@ -77,9 +72,7 @@ def test_directive_syntax_validator(directive): "!$DSL IMPORTS()", ], ) -def test_directive_semantics_validation_repeated_directives( - make_f90_tmpfile, directive -): +def test_directive_semantics_validation_repeated_directives(make_f90_tmpfile, directive): fpath = make_f90_tmpfile(content=SINGLE_STENCIL) opath = fpath.with_suffix(".gen") insert_new_lines(fpath, [directive]) @@ -115,9 +108,7 @@ def test_directive_semantics_validation_repeated_stencil(make_f90_tmpfile, direc """!$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; noprofile=True)""", ], ) -def test_directive_semantics_validation_required_directives( - make_f90_tmpfile, directive -): +def test_directive_semantics_validation_required_directives(make_f90_tmpfile, directive): new = SINGLE_STENCIL.replace(directive, "") fpath = make_f90_tmpfile(content=new) opath = fpath.with_suffix(".gen") diff --git a/tools/tests/liskov/test_writer.py b/tools/tests/liskov/test_writer.py index ed8cb512c0..e24410120e 100644 --- a/tools/tests/liskov/test_writer.py +++ b/tools/tests/liskov/test_writer.py @@ -64,9 +64,7 @@ def test_insert_generated_code(): "another line", "generated code2\n", ] - assert ( - CodegenWriter._insert_generated_code(current_file, generated) == expected_output - ) + assert CodegenWriter._insert_generated_code(current_file, generated) == expected_output def test_write_file():