-
Notifications
You must be signed in to change notification settings - Fork 80
/
broadcast.jl
138 lines (113 loc) · 5.31 KB
/
broadcast.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# broadcasting operations
export AbstractGPUArrayStyle
using Base.Broadcast
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
# Array types accepted as destinations / first arguments by the `copyto!`, `map` and
# `map!` methods below: any `AnyGPUArray{T}`, or a `Base.RefValue` holding an
# `AbstractGPUArray{T}`.
const BroadcastGPUArray{T} = Union{
    AnyGPUArray{T},
    Base.RefValue{<:AbstractGPUArray{T}}
}
"""
    AbstractGPUArrayStyle{N} <: AbstractArrayStyle{N}

Abstract supertype for GPU array broadcast styles. The `N` parameter is the
dimensionality.

Downstream implementations should provide a concrete array style type that inherits from
this supertype. Broadcasts whose style is a subtype of this one are materialized by the
`Broadcast.copy` methods for `Broadcasted{<:AbstractGPUArrayStyle}` in this file.

Concrete styles must follow the `AbstractArrayStyle` convention of being constructible
from a `Val(N)`: the `Ref` handling in this file builds a 0-dimensional variant of a
style via `Style(Val(0))`.
"""
abstract type AbstractGPUArrayStyle{N} <: AbstractArrayStyle{N} end
# Wrapper types would otherwise lose the information that they are GPU compatible, so
# both the broadcast style and the backend are taken from the unwrapped parent type,
# re-parameterized with the wrapper's element type and dimensionality.
# NOTE: deliberately not hard-coding `GPUArrayStyle` here, so that any downstream
# customization of the parent's style is preserved.
function BroadcastStyle(W::Type{<:WrappedGPUArray})
    ParentType = Adapt.parent(W){Adapt.eltype(W), Adapt.ndims(W)}
    return BroadcastStyle(ParentType)
end
function backend(W::Type{<:WrappedGPUArray})
    ParentType = Adapt.parent(W){Adapt.eltype(W), Adapt.ndims(W)}
    return backend(ParentType)
end
# `Ref` needs special treatment: it is not a real wrapper, so Adapt does not know about
# it, yet it is commonly used to shield an argument from broadcasting. To keep it
# dimensionless, reuse the wrapped array's style type but instantiate it as 0-dimensional.
function BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray}
    StyleType = typeof(BroadcastStyle(AT))
    return StyleType(Val(0))
end
function backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray}
    return backend(AT)
end
# 0-dimensional GPU broadcasts (e.g. ones involving only `Ref`-wrapped GPU arrays) must not
# dispatch to Base's optimized 0-dimensional `copy`, which indexes the arguments directly.
# Instead, materialize into a GPU destination and read back its single element.
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
    ElType = Broadcast.combine_eltypes(bc.f, bc.args)
    if !isbitstype(ElType)
        error("Cannot broadcast function returning non-isbits $ElType.")
    end
    dest = copyto!(similar(bc, ElType), bc)
    # a 0D broadcast produces a scalar, so unwrap the result; this is a single scalar
    # read, hence `@allowscalar`
    return @allowscalar dest[CartesianIndex()]
end
# Override the outer `copy` so a GPU broadcast never falls back to scalar iteration
# (see, e.g., CUDA.jl#145): allocate the destination with `similar` and fill it in place
# with `copyto!`. A non-concrete inferred element type is rejected up front.
@inline function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle})
    ElType = Broadcast.combine_eltypes(bc.f, bc.args)
    Base.isconcretetype(ElType) ||
        error("""GPU broadcast resulted in non-concrete element type $ElType.
                 This probably means that the function you are broadcasting contains an error or type instability.""")
    dest = similar(bc, ElType)
    return copyto!(dest, bc)
end
# Materialize a style-agnostic (`Broadcasted{Nothing}`) broadcast into `dest` by launching
# a GPU kernel, and return `dest`. Throws via `Broadcast.throwdm` when the axes of `dest`
# and `bc` differ; returns `dest` without launching anything when `dest` is empty.
@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing})
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    # prepare `bc` for being written into `dest` (see `Base.Broadcast.preprocess`;
    # presumably unaliases arguments that share memory with `dest`)
    bc′ = Broadcast.preprocess(dest, bc)

    # grid-stride kernel
    # Each thread runs up to `nelem` iterations; `@cartesianidx` (defined elsewhere) maps
    # iteration `i` of the current thread to a Cartesian index into `dest`.
    # NOTE(review): presumably `@cartesianidx` also exits the kernel once the index runs
    # past the end of `dest` -- confirm against its definition.
    function broadcast_kernel(ctx, dest, bc′, nelem)
        for i in 1:nelem
            I = @cartesianidx(dest, i)
            @inbounds dest[I] = bc′[I]
        end
        return
    end

    # Query a launch configuration that covers every element of `dest`. The `1` passed to
    # `launch_heuristic` stands in for the kernel's `nelem` argument; the real launch
    # passes `config.elements_per_thread` instead. `typemax(Int)` presumably leaves the
    # number of elements per thread unbounded -- confirm with `launch_configuration`.
    elements = length(dest)
    elements_per_thread = typemax(Int)
    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
                                 elements, elements_per_thread)
    config = launch_configuration(backend(dest), heuristic;
                                  elements, elements_per_thread)
    gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
             threads=config.threads, blocks=config.blocks)
    return dest
end
# Base provides a specialized method for 0-dimensional broadcast styles as a performance
# optimization, but `fill!` is not known to work for every `BroadcastGPUArray`. So convert
# to the style-agnostic `Broadcasted{Nothing}` and go straight to the generic kernel path.
@inline function Base.copyto!(dest::BroadcastGPUArray,
                              bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}})
    generic_bc = convert(Broadcasted{Nothing}, bc)
    return copyto!(dest, generic_bc)
end
## map

# Varargs helper: `true` when every argument compares `==` to the next one, checking
# left to right and stopping at the first mismatch. A single argument is trivially equal.
allequal(x) = true
function allequal(x, y, z...)
    x == y || return false
    return allequal(y, z...)
end
# `map` with a GPU array as first argument.
#
# When all arguments have the same size, the result keeps that shape and is computed
# with a fused broadcast. Otherwise the arguments are treated as iterators and the result
# is a new vector of length `minimum(length, (x, xs...))`, filled in by `map!`.
# Throws an `ErrorException` (via `error`) in that case if the inferred element type of
# `f`'s result is not an isbits type.
function Base.map(f, x::BroadcastGPUArray, xs::AbstractArray...)
    args = (x, xs...)

    # if argument sizes match, their shape needs to be preserved
    if allequal(size.(args)...)
        return f.(args...)
    end

    # if not, treat them as iterators, bounded by the shortest argument
    common_length = minimum(length, args)

    # infer the result element type to allocate the destination container
    ElType = Broadcast.combine_eltypes(f, args)
    isbitstype(ElType) || error("Cannot map function returning non-isbits $ElType.")
    dest = similar(x, ElType, common_length)

    # Return `dest` explicitly instead of `map!`'s return value: a `map!` method may
    # return something else (e.g. `nothing`) when there is nothing to write, but `map`
    # must always hand back the newly allocated array.
    map!(f, dest, args...)
    return dest
end
# In-place `map!` into a GPU destination.
#
# Writes `f` applied to the broadcast of `xs` into `dest`, covering only the first
# `common_length` linear positions, where `common_length` is the smallest length among
# `dest` and `xs`. Always returns `dest`, including when there is nothing to write.
function Base.map!(f, dest::BroadcastGPUArray, xs::AbstractArray...)
    # custom broadcast, ignoring the container size mismatches
    # (avoids the reshape + view that our mapreduce impl has to do)
    # NOTE(review): `Broadcast.instantiate` below presumably still throws
    # `DimensionMismatch` when the `xs` themselves are not broadcast-compatible
    # (e.g. lengths 3 and 2) -- confirm this is intended.
    common_length = minimum(length, (dest, xs...))
    # `map!` must return its destination, even when there is nothing to write
    common_length == 0 && return dest

    bc = Broadcast.instantiate(Broadcast.broadcasted(f, xs...))
    if bc isa Broadcast.Broadcasted
        # `broadcasted` may also return a non-`Broadcasted` value, which is used as-is
        bc = Broadcast.preprocess(dest, bc)
    end

    # grid-stride kernel: each thread runs up to `nelem` iterations, mapping iteration
    # `i` to a linear index `j` via `linear_index` (defined elsewhere), and stops once
    # `j` passes `common_length` (captured from the enclosing function).
    function map_kernel(ctx, dest, bc, nelem)
        for i in 1:nelem
            j = linear_index(ctx, i)
            j > common_length && return
            # read `bc` at the j-th Cartesian position of its own shape, but write `dest`
            # linearly, so `dest` and `bc` need not have the same shape
            J = CartesianIndices(axes(bc))[j]
            @inbounds dest[j] = bc[J]
        end
        return
    end

    # the `1` stands in for the kernel's `nelem` argument while computing the launch
    # configuration; the real launch passes `config.elements_per_thread`
    elements = common_length
    elements_per_thread = typemax(Int)
    heuristic = launch_heuristic(backend(dest), map_kernel, dest, bc, 1;
                                 elements, elements_per_thread)
    config = launch_configuration(backend(dest), heuristic;
                                  elements, elements_per_thread)
    gpu_call(map_kernel, dest, bc, config.elements_per_thread;
             threads=config.threads, blocks=config.blocks)
    return dest
end