Skip to content

Commit

Permalink
Add generated version of MPI_Recv_f08
Browse files Browse the repository at this point in the history
  • Loading branch information
jtronge committed Dec 20, 2023
1 parent 903186d commit 1c7c37f
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 48 deletions.
4 changes: 2 additions & 2 deletions ompi/mpi/fortran/use-mpi-f08/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ psizeof_f08.f90:
CLEANFILES += sizeof_f08.h sizeof_f08.f90 psizeof_f08.f90

mpi_api_generated_files = \
send_f08_generated.F90
send_f08_generated.F90 \
recv_f08_generated.F90
mpi_api_files = \
abort_f08.F90 \
accumulate_f08.F90 \
Expand Down Expand Up @@ -380,7 +381,6 @@ mpi_api_files = \
put_f08.F90 \
query_thread_f08.F90 \
raccumulate_f08.F90 \
recv_f08.F90 \
recv_init_f08.F90 \
reduce_f08.F90 \
reduce_init_f08.F90 \
Expand Down
3 changes: 2 additions & 1 deletion ompi/mpi/fortran/use-mpi-f08/base/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ noinst_LTLIBRARIES = libusempif08_ccode.la

libusempif08_ccode_la_SOURCES = \
buffer_detach.c \
send_f08_generated.c
send_f08_generated.c \
recv_f08_generated.c

if OMPI_GENERATE_BINDINGS
%_generated.c: ../%.in $(srcdir)/../generate_bindings.py
Expand Down
73 changes: 57 additions & 16 deletions ompi/mpi/fortran/use-mpi-f08/generate_bindings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
"""Fortran binding generation code.
This takes as input a *.in file containing the prototype of a Fortran function
with generic types. Both the Fortran subroutine and a wrapping C function can
generated from this file.
"""
from abc import ABC, abstractmethod
import argparse
import re


C_ERROR_TEMP_NAME = 'c_ierr'
GENERATED_MESSAGE = 'THIS FILE WAS AUTOMATICALLY GENERATED. DO NOT EDIT BY HAND.'
PROTOTYPE_RE = re.compile(r'^\w+\((\s*\w+\s+\w+\s*,?)+\)$')


class FortranType(ABC):

def __init__(self, name, **kwargs):
Expand Down Expand Up @@ -56,6 +67,12 @@ def c_post(self):
return []


#
# Definitions of generic types in Fortran and how these can be converted
# to and from C.
#


@FortranType.add('BUFFER')
class BufferType(FortranType):
def declare(self):
Expand Down Expand Up @@ -84,11 +101,16 @@ def c_argument(self):
return f'*{self.name}' if self.bigcount else f'OMPI_FINT_2_INT(*{self.name})'


def tmp_c_type(name):
def tmp_c_name(name):
"""Return a temporary name for use in C."""
return f'c_{name}'


def tmp_c_name2(name):
"""Return a secondary temporary name for use in C."""
return f'c_{name}2'


@FortranType.add('DATATYPE')
class DatatypeType(FortranType):
def declare(self):
Expand All @@ -107,10 +129,10 @@ def c_parameter(self):
return f'MPI_Fint *{self.name}'

def c_prepare(self):
return [f'MPI_Datatype {tmp_c_type(self.name)} = PMPI_Type_f2c(*{self.name});']
return [f'MPI_Datatype {tmp_c_name(self.name)} = PMPI_Type_f2c(*{self.name});']

def c_argument(self):
return tmp_c_type(self.name)
return tmp_c_name(self.name)


class IntType(FortranType):
Expand Down Expand Up @@ -152,13 +174,35 @@ def c_parameter(self):
return f'MPI_Fint *{self.name}'

def c_prepare(self):
return [f'MPI_Comm {tmp_c_type(self.name)} = PMPI_Comm_f2c(*{self.name});']
return [f'MPI_Comm {tmp_c_name(self.name)} = PMPI_Comm_f2c(*{self.name});']

def c_argument(self):
return tmp_c_type(self.name)
return tmp_c_name(self.name)


PROTOTYPE_RE = re.compile(r'^\w+\((\s*\w+\s+\w+\s*,?)+\)$')
@FortranType.add('STATUS')
class StatusType(FortranType):
def declare(self):
return f'TYPE(MPI_Status), INTENT(OUT) :: {self.name}'

def use(self):
return [('mpi_f08_types', 'MPI_Status')]

def c_parameter(self):
# TODO: Is this correct? (I've listed it as TYPE(MPI_Status) in the binding)
return f'MPI_Fint *{self.name}'

def c_prepare(self):
return [
f'OMPI_FORTRAN_STATUS_DECLARATION({tmp_c_name(self.name)}, {tmp_c_name2(self.name)});',
f'OMPI_FORTRAN_STATUS_SET_POINTER({tmp_c_name(self.name)}, {tmp_c_name2(self.name)}, {self.name});'
]

def c_argument(self):
return tmp_c_name(self.name)

def c_post(self):
return [f'OMPI_FORTRAN_STATUS_RETURN({tmp_c_name(self.name)}, {tmp_c_name2(self.name)}, {self.name}, {C_ERROR_TEMP_NAME});']


class PrototypeParseError(Exception):
Expand Down Expand Up @@ -186,10 +230,6 @@ def print_header():
print('#include "mpi-f08-rename.h"')



GENERATED_MESSAGE = 'THIS FILE WAS AUTOMATICALLY GENERATED. DO NOT EDIT BY HAND.'


class FortranBinding:

def __init__(self, fname):
Expand Down Expand Up @@ -272,7 +312,7 @@ def print_f_source(self):
# Add the integer error manually
print(' INTEGER, OPTIONAL, INTENT(OUT) :: ierror')
# Temporaries
print(' INTEGER :: c_ierror')
print(f' INTEGER :: {C_ERROR_TEMP_NAME}')

# Interface for call to C function
print()
Expand All @@ -281,9 +321,9 @@ def print_f_source(self):

# Call into the C function
args = ','.join(param.argument() for param in self.parameters)
print(f' call {c_func}({args},c_ierror)')
print(f' call {c_func}({args},{C_ERROR_TEMP_NAME})')
# Convert error type
print(' if (present(ierror)) ierror = c_ierror')
print(f' if (present(ierror)) ierror = {C_ERROR_TEMP_NAME}')

print(f'end subroutine {sub_name}')

Expand All @@ -292,6 +332,7 @@ def print_c_source(self):
print(f'/* {GENERATED_MESSAGE} */')
print('#include "ompi_config.h"')
print('#include "mpi.h"')
print('#include "ompi/mpi/fortran/mpif-h/status-conversion.h"')
print('#include "ompi/mpi/fortran/base/constants.h"')
print('#include "ompi/mpi/fortran/base/fint_2_int.h"')
c_func = c_func_name(self.fn_name)
Expand All @@ -303,19 +344,19 @@ def print_c_source(self):
print(f'void {c_func}({parameters});')
print(f'void {c_func}({parameters})')
print('{')
print(' int c_ierr; ')
print(f' int {C_ERROR_TEMP_NAME}; ')
for param in self.parameters:
for line in param.c_prepare():
print(f' {line}')
c_api_func = c_api_func_name(self.fn_name)
arguments = [param.c_argument() for param in self.parameters]
arguments = ', '.join(arguments)
print(f' c_ierr = {c_api_func}({arguments});')
print(f' {C_ERROR_TEMP_NAME} = {c_api_func}({arguments});')
for param in self.parameters:
for line in param.c_post():
print(f' {line}')
# TODO: Is this NULL check necessary for mpi_f08?
print(' if (NULL != ierr) *ierr = OMPI_INT_2_FINT(c_ierr);')
print(f' if (NULL != ierr) *ierr = OMPI_INT_2_FINT({C_ERROR_TEMP_NAME});')
print('}')


Expand Down
29 changes: 0 additions & 29 deletions ompi/mpi/fortran/use-mpi-f08/recv_f08.F90

This file was deleted.

1 change: 1 addition & 0 deletions ompi/mpi/fortran/use-mpi-f08/recv_f08.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recv(BUFFER buf, COUNT count, DATATYPE datatype, RANK source, TAG tag, COMM comm, STATUS status)

0 comments on commit 1c7c37f

Please sign in to comment.