Commit 184b36f

Minimize diff.

maleadt committed Sep 20, 2024
1 parent 472db83
Showing 9 changed files with 117 additions and 118 deletions.
74 changes: 39 additions & 35 deletions lib/JLArrays/src/JLArrays.jl
@@ -1,17 +1,20 @@
# reference implementation on the CPU
# This acts as a wrapper around KernelAbstractions's parallel CPU
# functionality. It is useful for testing GPUArrays (and other packages)
# when no GPU is present.
# This file follows conventions from AMDGPU.jl

module JLArrays

export JLArray, JLVector, JLMatrix, jl, JLBackend

using GPUArrays

using Adapt

import KernelAbstractions
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

export JLArray, JLVector, JLMatrix, jl, JLBackend

#
# Device functionality
@@ -24,7 +27,6 @@ struct JLBackend <: KernelAbstractions.GPU
JLBackend(;static::Bool=false) = new(static)
end


struct Adaptor end
jlconvert(arg) = adapt(Adaptor(), arg)

@@ -35,37 +37,7 @@ end
Base.getindex(r::JlRefValue) = r.x
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))

mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
data::DataRef{Vector{UInt8}}

offset::Int # offset of the data in the buffer, in number of elements

dims::Dims{N}

# allocating constructor
function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
check_eltype(T)
maxsize = prod(dims) * sizeof(T)
data = Vector{UInt8}(undef, maxsize)
ref = DataRef(data) do data
resize!(data, 0)
end
obj = new{T,N}(ref, 0, dims)
finalizer(unsafe_free!, obj)
end

# low-level constructor for wrapping existing data
function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
offset::Int=0) where {T,N}
check_eltype(T)
obj = new{T,N}(ref, offset, dims)
finalizer(unsafe_free!, obj)
end
end

Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
## executed on-device

# array type

@@ -91,6 +63,7 @@ end
@inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
@inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)


#
# Host abstractions
#
@@ -104,6 +77,34 @@ function check_eltype(T)
end
end

mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
data::DataRef{Vector{UInt8}}

offset::Int # offset of the data in the buffer, in number of elements

dims::Dims{N}

# allocating constructor
function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
check_eltype(T)
maxsize = prod(dims) * sizeof(T)
data = Vector{UInt8}(undef, maxsize)
ref = DataRef(data) do data
resize!(data, 0)
end
obj = new{T,N}(ref, 0, dims)
finalizer(unsafe_free!, obj)
end

# low-level constructor for wrapping existing data
function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
offset::Int=0) where {T,N}
check_eltype(T)
obj = new{T,N}(ref, offset, dims)
finalizer(unsafe_free!, obj)
end
end

unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)

# conversion of untyped data to a typed Array
@@ -380,7 +381,10 @@ function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
device_args = jlconvert.(args)
new_obj = convert_to_cpu(obj)
new_obj(device_args...; ndrange, workgroupsize)

end

Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)

end
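
To illustrate how the pieces above fit together: the Adapt.adapt_storage rules move data between backends, and calling a KernelAbstractions kernel object on JLBackend converts both the kernel and its arguments to their CPU forms via convert_to_cpu and jlconvert. A minimal sketch (the scale! kernel is illustrative, not part of the package):

    using JLArrays, KernelAbstractions, Adapt

    a = adapt(JLBackend(), rand(Float32, 16))   # Array -> JLArray
    b = adapt(KernelAbstractions.CPU(), a)      # JLArray -> plain Array

    # a KA kernel instantiated for JLBackend runs through the CPU backend
    @kernel function scale!(x, s)
        i = @index(Global, Linear)
        @inbounds x[i] *= s
    end
    scale!(JLBackend())(a, 2f0; ndrange = length(a))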
5 changes: 3 additions & 2 deletions src/GPUArrays.jl
@@ -15,10 +15,11 @@ using LLVM.Interop
using Reexport
@reexport using GPUArraysCore

## executed on-device
using KernelAbstractions

# device functionality
include("device/abstractarray.jl")

using KernelAbstractions
# host abstractions
include("host/abstractarray.jl")
include("host/construction.jl")
1 change: 0 additions & 1 deletion src/host/abstractarray.jl
@@ -159,7 +159,6 @@ for (D, S) in ((AnyGPUArray, Array),
end

# kernel-based variant for copying between wrapped GPU arrays
# TODO: Add `@Const` to `src`
@kernel function linear_copy_kernel!(dest, dstart, src, sstart, n)
i = @index(Global, Linear)
if i <= n
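
For reference, the TODO refers to KernelAbstractions' @Const argument annotation, which declares a kernel argument read-only so backends may optimize its accesses. A hedged sketch of what the annotated kernel could look like — the body below the `if` is reconstructed, since the diff view truncates it:

    @kernel function linear_copy_kernel!(dest, dstart, @Const(src), sstart, n)
        i = @index(Global, Linear)
        if i <= n
            # assumed body: copy one element at the offset positions
            @inbounds dest[dstart+i-1] = src[sstart+i-1]
        end
    end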
12 changes: 6 additions & 6 deletions src/host/broadcast.jl
@@ -59,16 +59,16 @@ end
@inbounds dest[I] = bc[I]
end

# grid-stride kernel, ndrange set for possible 0D evaluation
if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) &&
broadcast_kernel = if ndims(dest) == 1 ||
(isa(IndexStyle(dest), IndexLinear) &&
isa(IndexStyle(bc), IndexLinear))
broadcast_kernel_linear(get_backend(dest))(dest, bc;
ndrange = length(size(dest)) > 0 ? length(dest) : 1)
broadcast_kernel_linear(get_backend(dest))
else
broadcast_kernel_cartesian(get_backend(dest))(dest, bc;
ndrange = sz = length(size(dest)) > 0 ? size(dest) : (1,))
broadcast_kernel_cartesian(get_backend(dest))
end

# ndims check for 0D support
broadcast_kernel(dest, bc; ndrange = ndims(dest) > 0 ? size(dest) : (1,))
if eltype(dest) <: BrokenBroadcast
throw(ArgumentError("Broadcast operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
end
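
The rewritten code first selects the kernel variant (linear indexing when both dest and bc support it, cartesian otherwise) and then launches it through one shared call, so the 0-dimensional ndrange fallback is written only once. A small usage sketch of the case that fallback exists for:

    using JLArrays

    x = jl(fill(1.0))   # 0-dimensional array: ndims(x) == 0, size(x) == ()
    y = x .+ 1          # size(x) is empty, so the kernel launches with ndrange = (1,)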
7 changes: 4 additions & 3 deletions src/host/construction.jl
@@ -11,14 +11,15 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T

function Base.fill!(A::AnyGPUArray{T}, x) where T
isempty(A) && return A

@kernel function fill_kernel!(a, val)
idx = @index(Global, Linear)
@inbounds a[idx] = val
end

# ndrange set for a possible 0D evaluation
fill_kernel!(get_backend(A))(A, x,
ndrange = length(size(A)) > 0 ? size(A) : (1,))
# ndims check for 0D support
kernel = fill_kernel!(get_backend(A))
kernel(A, x; ndrange = ndims(A) > 0 ? size(A) : (1,))
A
end

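
The fill! change follows the same pattern: instantiate the kernel for the array's backend, then launch with ndims(A) > 0 ? size(A) : (1,) so that 0-d arrays still get one work item. Usage is unchanged; for instance (a sketch):

    using JLArrays

    A = JLArray{Float32,2}(undef, (3, 3))
    fill!(A, 1f0)       # fill_kernel! runs with ndrange = (3, 3)

    B = jl(fill(0.0))   # 0-d array
    fill!(B, 2.0)       # ndims(B) == 0, so ndrange = (1,)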
13 changes: 4 additions & 9 deletions src/host/indexing.jl
@@ -82,14 +82,11 @@ end
return vectorized_getindex!(dest, src, Is...)
end

@kernel function getindex_kernel(dest, src, idims,
Is::Vararg{Any,N}) where {N}
@kernel function getindex_kernel(dest, src, idims, Is...)
i = @index(Global, Linear)
getindex_generated(dest, src, idims, i, Is...)
end

@generated function getindex_generated(dest, src, idims, i,
Is::Vararg{Any,N}) where {N}
@generated function getindex_generated(dest, src, idims, i, Is::Vararg{Any,N}) where {N}
quote
is = @inbounds CartesianIndices(idims)[i]
@nexprs $N i -> I_i = @inbounds(Is[i][is[i]])
@@ -120,13 +117,11 @@ end
return dest
end

@kernel function setindex_kernel(dest, src, idims, len,
Is::Vararg{Any,N}) where {N}
@kernel function setindex_kernel(dest, src, idims, len, Is...)
i = @index(Global, Linear)
setindex_generated(dest, src, idims, len, i, Is...)
end
@generated function setindex_generated(dest, src, idims, len, i,
Is::Vararg{Any,N}) where {N}
@generated function setindex_generated(dest, src, idims, len, i, Is::Vararg{Any,N}) where {N}
quote
i > len && return
is = @inbounds CartesianIndices(idims)[i]
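
The Vararg{Any,N} constraint stays on the @generated helpers, where N is what drives @nexprs; the @kernel wrappers merely forward their arguments, so a plain Is... slurp suffices there. A sketch of the vectorized-indexing calls that reach these kernels (array values are illustrative):

    using JLArrays

    a = jl(collect(reshape(1f0:16f0, 4, 4)))
    b = a[[1, 3], [2, 4]]                          # vectorized getindex -> getindex_kernel
    a[[1, 3], [2, 4]] = jl(zeros(Float32, 2, 2))   # vectorized setindex! -> setindex_kernel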
