Skip to content

Commit

Permalink
Merge branch 'main' into multi-volume
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored Mar 25, 2022
2 parents 85d9812 + 7a8dd5d commit e6e190b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 15 deletions.
2 changes: 1 addition & 1 deletion examples/old_symbolics/dagrt-fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def isolate_function_calls_in_phase(phase, stmt_id_gen, var_name_gen):
stmt_id_gen=stmt_id_gen,
var_name_gen=var_name_gen)

for stmt in sorted(phase.statements, key=lambda stmt: stmt.id):
for stmt in sorted(phase.statements, key=lambda stmt_: stmt_.id):
new_deps = []

from dagrt.language import Assign
Expand Down
27 changes: 27 additions & 0 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
any, for now.)
"""
def __init__(self, queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
wait_event_queue_length: Optional[int] = None,
force_device_scalars: bool = False) -> None:

if allocator is None:
from warnings import warn
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)")

super().__init__(queue, allocator,
wait_event_queue_length, force_device_scalars)

# }}}

Expand All @@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
Extends it to understand :mod:`grudge`-specific transform metadata. (Of
which there isn't any, for now.)
"""
def __init__(self, queue, allocator=None):
if allocator is None:
from warnings import warn
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)")
super().__init__(queue, allocator)

# }}}

Expand Down Expand Up @@ -210,6 +230,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
out_dict = execute_distributed_partition(
self.distributed_partition, self.part_id_to_prg,
self.actx.queue, self.actx.mpi_communicator,
allocator=self.actx.allocator,
input_args=input_args_for_prg)

def to_output_template(keys, _):
Expand All @@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext):
def __init__(
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None
) -> None:
if allocator is None:
from warnings import warn
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)")

super().__init__(queue, allocator)

self.mpi_communicator = mpi_communicator
Expand Down
30 changes: 16 additions & 14 deletions test/test_mpi_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,24 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
bdry_faces_func = op.project(dcoll, BTAG_ALL, dd_af,
op.project(dcoll, dd_vol, BTAG_ALL, myfunc))

hopefully_zero = (
op.project(
dcoll, "int_faces", "all_faces",
dcoll.opposite_face_connection(
dof_desc.BoundaryDomainTag(
dof_desc.FACE_RESTR_INTERIOR, dof_desc.VTAG_ALL)
)(int_faces_func)
)
+ sum(op.project(dcoll, tpair.dd, "all_faces", tpair.int)
for tpair in op.cross_rank_trace_pairs(dcoll, myfunc,
comm_tag=SimpleTag))
) - (all_faces_func - bdry_faces_func)
def hopefully_zero():
return (
op.project(
dcoll, "int_faces", "all_faces",
dcoll.opposite_face_connection(
dof_desc.BoundaryDomainTag(
dof_desc.FACE_RESTR_INTERIOR, dof_desc.VTAG_ALL)
)(int_faces_func)
)
+ sum(op.project(dcoll, tpair.dd, "all_faces", tpair.ext)
for tpair in op.cross_rank_trace_pairs(dcoll, myfunc,
comm_tag=SimpleTag))
) - (all_faces_func - bdry_faces_func)

hopefully_zero_result = actx.compile(hopefully_zero)()

error = actx.to_numpy(flat_norm(hopefully_zero, ord=np.inf))
error = actx.to_numpy(flat_norm(hopefully_zero_result, ord=np.inf))

print(__file__)
with np.printoptions(threshold=100000000, suppress=True):
logger.debug(hopefully_zero)
logger.info("error: %.5e", error)
Expand Down

0 comments on commit e6e190b

Please sign in to comment.