Skip to content

Commit

Permalink
Add Buffer type, improve Datatype handling
Browse files Browse the repository at this point in the history
This contains two related changes:

1. Defines a specific `Buffer` type, which contains the reference to the storage buffer, its count and datatype. This allows us to simplify the type signatures of various functions, as `count` and `datatype` no longer need to be arguments to the functions. This also adds default conversion methods for `Array`s and `Subarray`s (creating the derived datatypes where necessary, and determining the appropriate `count`s), and moves the point-to-point operations to use these conversions.

2. Improves the handling of `Datatype` handles, by making them garbage-collected objects (like other MPI handles), moves lower-level functions to a submodule, defines consistent interfaces. Also fixes #327.

I still need to move the collective calls over as well, however that will require more thought on how to handle the "chunked" operations like scatter/gather.

I also removed the inverse dictionary mappings from MPI Datatype -> Julia Type, as that is no longer so easy to determine.
  • Loading branch information
simonbyrne committed Jan 2, 2020
1 parent ef5cfee commit f972016
Show file tree
Hide file tree
Showing 16 changed files with 748 additions and 431 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ version = "0.11.0"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"

[compat]
julia = "1"
Requires = "~0.5"
julia = "1"

[extras]
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
Expand Down
4 changes: 4 additions & 0 deletions deps/consts_msmpi.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# From https://github.com/microsoft/Microsoft-MPI/blob/v10.0/src/include/mpi.h

const MPI_Aint = Int
const MPI_Offset = Int64
const MPI_Count = Int64

for T in [:MPI_Comm, :MPI_Info, :MPI_Win, :MPI_Request, :MPI_Op, :MPI_Datatype]
@eval begin
primitive type $T 32 end
Expand Down
4 changes: 4 additions & 0 deletions deps/gen_consts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ int main(int argc, char *argv[]) {
fprintf(fptr, "# Do not edit\\n");
""")

println(f," fprintf(fptr, \"const MPI_Aint = Int%d\\n\", 8*(int)sizeof(MPI_Aint));")
println(f," fprintf(fptr, \"const MPI_Offset = Int%d\\n\", 8*(int)sizeof(MPI_Offset));")
println(f," fprintf(fptr, \"const MPI_Count = Int%d\\n\", 8*(int)sizeof(MPI_Count));")

println(f," fprintf(fptr, \"const MPI_Status_size = %d\\n\", (int)sizeof(MPI_Status));")
println(f," fprintf(fptr, \"const MPI_Status_Source_offset = %d\\n\", (int)offsetof(MPI_Status, MPI_SOURCE));")
println(f," fprintf(fptr, \"const MPI_Status_Tag_offset = %d\\n\", (int)offsetof(MPI_Status, MPI_TAG));")
Expand Down
18 changes: 16 additions & 2 deletions docs/src/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,25 @@ MPI.refcount_inc
MPI.refcount_dec
```

## Buffers

```@docs
MPI.Buffer
MPI.Buffer_send
MPI.MPIPtr
```

## Datatype objects

```@docs
MPI.mpitype
MPI.Type_Create_Subarray
MPI.Datatype
MPI.Types.extent
MPI.Types.create_contiguous
MPI.Types.create_vector
MPI.Types.create_subarray
MPI.Types.create_struct
MPI.Types.create_resized
MPI.Types.commit!
```

## Operator objects
Expand Down
3 changes: 3 additions & 0 deletions src/MPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module MPI

using Libdl, Serialization
using Requires
using DocStringExtensions

macro mpichk(expr)
@assert expr isa Expr && expr.head == :call && expr.args[1] == :ccall
Expand Down Expand Up @@ -38,11 +39,13 @@ function _doc_external(fname)
end

include(joinpath(@__DIR__, "..", "deps", "deps.jl"))

include("handle.jl")
include("info.jl")
include("comm.jl")
include("environment.jl")
include("datatypes.jl")
include("buffers.jl")
include("operators.jl")
include("pointtopoint.jl")
include("collective.jl")
Expand Down
127 changes: 127 additions & 0 deletions src/buffers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
const MPIInteger = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
const MPIFloatingPoint = Union{Float32, Float64}
const MPIComplex = Union{ComplexF32, ComplexF64}

const MPIDatatype = Union{Char,
Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
UInt64,
Float32, Float64, ComplexF32, ComplexF64}
MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}}

MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr}

Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x)
function Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T
Base.cconvert(Ptr{T}, x)
end
function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T
ptr = Base.unsafe_convert(Ptr{T}, x)
reinterpret(MPIPtr, ptr)
end
function Base.cconvert(::Type{MPIPtr}, ::Nothing)
reinterpret(MPIPtr, C_NULL)
end

macro assert_minlength(buffer, count)
quote
if $(esc(buffer)) isa AbstractArray
@assert length($(esc(buffer))) >= $(esc(count))
end
end
end

"""
MPI.MPIPtr
A pointer to an MPI buffer. This type is used only as part of the implicit conversion in
`ccall`: a Julia object can be passed to MPI by defining methods for
`Base.cconvert(::Type{MPIPtr}, ...)`/`Base.unsafe_convert(::Type{MPIPtr}, ...)`.
Currently supported are:
- `Ptr`
- `Ref`
- `Array`
- `SubArray`
- `CuArray` if CuArrays.jl is loaded.
Additionally, certain sentinel values can be used, e.g. `MPI_IN_PLACE` or `MPI_BOTTOM`.
"""
MPIPtr


"""
MPI.Buffer
An MPI buffer for communication operations.
# Fields
$(DocStringExtensions.FIELDS)
# Usage
Buffer(data, count::Integer, datatype::Datatype)
Generic constructor.
Buffer(data)
Construct a `Buffer` backed by `data`, automatically determining the appropriate `count`
and `datatype`. Methods are provided for
- `Ref`
- `Array`
- `CuArray` if CuArrays.jl is loaded
- `SubArray`s of an `Array` or `CuArray` where the layout is contiguous, sequential or
blocked.
"""
struct Buffer{A}
"""a Julia object referencing a region of memory to be used for communication. It is
required that the object can be `cconvert`ed to an [`MPIPtr`](@ref)."""
data::A

"""the number of elements of `datatype` in the buffer. Note that this may not
correspond to the number of elements in the array if derived types are used."""
count::Cint

"""the [`MPI.Datatype`](@ref) stored in the buffer."""
datatype::Datatype
end
Buffer(buf::Buffer) = buf
Buffer(data, count::Integer, datatype::Datatype) = Buffer(data, Cint(count), datatype)

function Buffer(arr::Array)
Buffer(arr, Cint(length(arr)), Datatype(eltype(arr)))
end
function Buffer(ref::Ref)
Buffer(ref, Cint(1), Datatype(eltype(ref)))
end

# SubArray
function Buffer(sub::Base.FastContiguousSubArray)
Buffer(sub, Cint(length(sub)), Datatype(eltype(sub)))
end
function Buffer(sub::Base.FastSubArray)
datatype = Types.create_vector(length(sub), 1, sub.stride1,
Datatype(eltype(sub); commit=false))
Types.commit!(datatype)
Buffer(sub, Cint(1), datatype)
end
function Buffer(sub::SubArray{T,N,P,I,false}) where {T,N,P,I<:Tuple{Vararg{Union{Base.ScalarIndex, Base.Slice, AbstractUnitRange}}}}
datatype = Types.create_subarray(size(parent(sub)),
map(length, sub.indices),
map(i -> first(i)-1, sub.indices),
Datatype(eltype(sub), commit=false))
Types.commit!(datatype)
Buffer(parent(sub), Cint(1), datatype)
end

"""
Buffer_send(data)
Construct a [`Buffer`](@ref) object for a send operation from `data`, allowing cases where
`isbits(data)`.
"""
Buffer_send(data) = isbits(data) ? Buffer(Ref(data)) : Buffer(data)

const BUFFER_NULL = Buffer(C_NULL, 0, DATATYPE_NULL)
26 changes: 13 additions & 13 deletions src/collective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function Bcast!(buffer, count::Integer,
# MPI_Comm comm)
@mpichk ccall((:MPI_Bcast, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm),
buffer, count, mpitype(eltype(buffer)), root, comm)
buffer, count, Datatype(eltype(buffer)), root, comm)
buffer
end

Expand Down Expand Up @@ -105,7 +105,7 @@ function Scatter!(sendbuf, recvbuf, count::Integer, root::Integer, comm::Comm)
# MPI_Comm comm)
@mpichk ccall((:MPI_Scatter, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm),
sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), root, comm)
sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), root, comm)
recvbuf
end

Expand Down Expand Up @@ -174,7 +174,7 @@ function Scatterv!(sendbuf, recvbuf, counts::Vector, root::Integer, comm::Comm)
# int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
@mpichk ccall((:MPI_Scatterv, libmpi), Cint,
(MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm),
sendbuf, counts, disps, mpitype(T), recvbuf, recvcnt, mpitype(T), root, comm)
sendbuf, counts, disps, Datatype(T), recvbuf, recvcnt, Datatype(T), root, comm)
recvbuf
end

Expand Down Expand Up @@ -245,7 +245,7 @@ function Gather!(sendbuf, recvbuf, count::Integer, root::Integer, comm::Comm)
# MPI_Comm comm)
@mpichk ccall((:MPI_Gather, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, MPI_Comm),
sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), root, comm)
sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), root, comm)
isroot ? recvbuf : nothing
end
function Gather!(sendbuf, recvbuf, root::Integer, comm::Comm)
Expand Down Expand Up @@ -305,7 +305,7 @@ function Allgather!(sendbuf, recvbuf, count::Integer, comm::Comm)
# MPI_Datatype recvtype, MPI_Comm comm)
@mpichk ccall((:MPI_Allgather, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, MPI_Comm),
sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), comm)
sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), comm)
recvbuf
end
function Allgather!(sendrecvbuf, count::Integer, comm::Comm)
Expand Down Expand Up @@ -381,7 +381,7 @@ function Gatherv!(sendbuf, recvbuf, counts::Vector{Cint}, root::Integer, comm::C
# MPI_Datatype recvtype, int root, MPI_Comm comm)
@mpichk ccall((:MPI_Gatherv, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, Cint, MPI_Comm),
sendbuf, sendcnt, mpitype(T), recvbuf, counts, displs, mpitype(T), root, comm)
sendbuf, sendcnt, Datatype(T), recvbuf, counts, displs, Datatype(T), root, comm)
isroot ? recvbuf : nothing
end

Expand Down Expand Up @@ -436,7 +436,7 @@ function Allgatherv!(sendbuf, recvbuf, counts::Vector{Cint}, comm::Comm)
# const int displs[], MPI_Datatype recvtype, MPI_Comm comm)
@mpichk ccall((:MPI_Allgatherv, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPI_Comm),
sendbuf, sendcnt, mpitype(T), recvbuf, counts, displs, mpitype(T), comm)
sendbuf, sendcnt, Datatype(T), recvbuf, counts, displs, Datatype(T), comm)
recvbuf
end
function Allgatherv!(sendrecvbuf, counts::Vector{Cint}, comm::Comm)
Expand Down Expand Up @@ -499,7 +499,7 @@ function Alltoall!(sendbuf, recvbuf, count::Integer, comm::Comm)
# MPI_Comm comm)
@mpichk ccall((:MPI_Alltoall, libmpi), Cint,
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, MPI_Comm),
sendbuf, count, mpitype(T), recvbuf, count, mpitype(T), comm)
sendbuf, count, Datatype(T), recvbuf, count, Datatype(T), comm)
recvbuf
end
function Alltoall!(sendrecvbuf, count::Integer, comm::Comm)
Expand Down Expand Up @@ -558,7 +558,7 @@ function Alltoallv!(sendbuf, recvbuf, scounts::Vector{Cint}, rcounts::Vector{Cin
# MPI_Datatype recvtype, MPI_Comm comm)
@mpichk ccall((:MPI_Alltoallv, libmpi), Cint,
(MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPIPtr, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, MPI_Comm),
sendbuf, scounts, sdispls, mpitype(T), recvbuf, rcounts, rdispls, mpitype(T), comm)
sendbuf, scounts, sdispls, Datatype(T), recvbuf, rcounts, rdispls, Datatype(T), comm)
recvbuf
end

Expand Down Expand Up @@ -616,7 +616,7 @@ function Reduce!(sendbuf, recvbuf, count::Integer, op::Union{Op,MPI_Op}, root::I
# MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
@mpichk ccall((:MPI_Reduce, libmpi), Cint,
(MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, Cint, MPI_Comm),
sendbuf, recvbuf, count, mpitype(T), op, root, comm)
sendbuf, recvbuf, count, Datatype(T), op, root, comm)
recvbuf
end

Expand Down Expand Up @@ -699,7 +699,7 @@ function Allreduce!(sendbuf, recvbuf, count::Integer, op::Union{Op,MPI_Op}, comm
# MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
@mpichk ccall((:MPI_Allreduce, libmpi), Cint,
(MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm),
sendbuf, recvbuf, count, mpitype(T), op, comm)
sendbuf, recvbuf, count, Datatype(T), op, comm)
recvbuf
end
function Allreduce!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm)
Expand Down Expand Up @@ -766,7 +766,7 @@ function Scan!(sendbuf, recvbuf, count::Integer,
# MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
@mpichk ccall((:MPI_Scan, libmpi), Cint,
(MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm),
sendbuf, recvbuf, count, mpitype(T), op, comm)
sendbuf, recvbuf, count, Datatype(T), op, comm)
recvbuf
end
function Scan!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm)
Expand Down Expand Up @@ -840,7 +840,7 @@ function Exscan!(sendbuf, recvbuf, count::Integer,
# MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
@mpichk ccall((:MPI_Exscan, libmpi), Cint,
(MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm),
sendbuf, recvbuf, count, mpitype(T), op, comm)
sendbuf, recvbuf, count, Datatype(T), op, comm)
recvbuf
end
function Exscan!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm)
Expand Down
15 changes: 13 additions & 2 deletions src/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ function Base.unsafe_convert(::Type{MPIPtr}, buf::DeviceBuffer)
reinterpret(MPIPtr, buf.ptr)
end
# CuArrays > v1.3
function Base.unsafe_convert(::Type{MPIPtr}, buf::CuArray{T}) where T
reinterpret(MPIPtr, Base.unsafe_convert(CuPtr{T}, buf))
function Base.unsafe_convert(::Type{MPIPtr}, X::CuArray{T}) where T
reinterpret(MPIPtr, Base.unsafe_convert(CuPtr{T}, X))
end
# only need to define this for strided arrays: all others can be handled by generic machinery
function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CuArray,I}
X = parent(V)
pX = Base.unsafe_convert(CuPtr{T}, X)
pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T)
return reinterpret(MPIPtr, pV)
end

function Buffer(arr::CuArray)
Buffer(arr, Cint(length(arr)), Datatype(eltype(arr)))
end
Loading

0 comments on commit f972016

Please sign in to comment.