diff --git a/src/dual.jl b/src/dual.jl index 75d46378..625504ab 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -68,7 +68,7 @@ end @inline Dual{T}(value, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials)) @inline Dual{T}(value) where {T} = Dual{T}(value, ()) @inline Dual{T}(x::Dual{T}) where {T} = Dual{T}(x, ()) -@inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(partial1, partials...)) +@inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(VecElement(partial1), VecElement.(partials)...)) @inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p)) @inline Dual(args...) = Dual{Nothing}(args...) diff --git a/src/partials.jl b/src/partials.jl index 7a94884e..c693ef3d 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -1,5 +1,5 @@ struct Partials{N,V} <: AbstractVector{V} - values::NTuple{N,V} + values::NTuple{N,VecElement{V}} end ############################## @@ -7,7 +7,7 @@ end ############################## @generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i} - ex = Expr(:tuple, [ifelse(i === j, :(one(V)), :(zero(V))) for j in 1:N]...) + ex = Expr(:tuple, [ifelse(i === j, :(VecElement(one(V))), :(VecElement(zero(V)))) for j in 1:N]...) return :(Partials($(ex))) end @@ -20,10 +20,11 @@ end @inline Base.length(::Partials{N}) where {N} = N @inline Base.size(::Partials{N}) where {N} = (N,) -@inline Base.@propagate_inbounds Base.getindex(partials::Partials, i::Int) = partials.values[i] +@inline Base.@propagate_inbounds Base.getindex(partials::Partials, i::Int) = partials.values[i].value -Base.iterate(partials::Partials) = iterate(partials.values) -Base.iterate(partials::Partials, i) = iterate(partials.values, i) + +Base.iterate(partials::Partials{0}) = nothing +Base.iterate(partials::Partials, i=1) = i > length(partials) ? nothing : (partials.values[i].value, i+1) Base.IndexStyle(::Type{<:Partials}) = IndexLinear() @@ -47,8 +48,11 @@ Base.mightalias(x::AbstractArray, y::Partials) = false @inline Random.rand(rng::AbstractRNG, partials::Partials) = rand(rng, typeof(partials)) @inline Random.rand(rng::AbstractRNG, ::Type{Partials{N,V}}) where {N,V} = Partials{N,V}(rand_tuple(rng, NTuple{N,V})) -Base.isequal(a::Partials{N}, b::Partials{N}) where {N} = isequal(a.values, b.values) -Base.:(==)(a::Partials{N}, b::Partials{N}) where {N} = a.values == b.values +Base.isequal(a::Partials{N}, b::Partials{N}) where {N} = all(i->isequal(a[i], b[i]), 1:N) +Base.:(==)(a::Partials{N}, b::Partials{N}) where {N} = all(i->a[i] == b[i], 1:N) +#Base.:(==)(a::Partials{N}, b::NTuple{N}) where {N} = all(i->a[i] == b[i], 1:N) +#Base.:(==)(a::NTuple{N}, b::Partials{N}) where {N} = all(i->a[i] == b[i], 1:N) + const PARTIALS_HASH = hash(Partials) @@ -71,7 +75,7 @@ end Base.promote_rule(::Type{Partials{N,A}}, ::Type{Partials{N,B}}) where {N,A,B} = Partials{N,promote_type(A, B)} -Base.convert(::Type{Partials{N,V}}, partials::Partials) where {N,V} = Partials{N,V}(partials.values) +Base.convert(::Type{Partials{N,V}}, partials::Partials) where {N,V} = Partials{N,V}(ntuple(i->VecElement(V(partials[i])), N)) Base.convert(::Type{Partials{N,V}}, partials::Partials{N,V}) where {N,V} = partials ######################## @@ -157,6 +161,11 @@ const SIMDType = Union{SIMDFloat, SIMDInt} # faster since they generate inline code # that doesn't rely on closures. + +const NVE{N,T} = NTuple{N,VecElement{T}} +const NT{N,T} = NTuple{N,T} + + function tupexpr(f, N) ex = Expr(:tuple, [f(i) for i=1:N]...) return quote @@ -171,11 +180,11 @@ end @inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple() @inline rand_tuple(::Type{Tuple{}}) = tuple() -iszero_tuple(tup::NTuple{N,V}) where {N, V<:SIMDType} = sum(Vec(tup) != zero(V)) == 0 -@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V} - ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...) +iszero_tuple(tup::NVE{N,V}) where {N, V<:SIMDType} = sum(Vec(tup) != zero(V)) == 0 +@generated function iszero_tuple(tup::NTuple{N,VecElement{V}}) where {N,V} + ex = Expr(:&&, [:(z == tup[$i].value) for i=1:N]...) return quote - z = zero(V) + z = VecElement(zero(V)) $(Expr(:meta, :inline)) @inbounds return $ex end @@ -184,7 +193,7 @@ end @generated function zero_tuple(::Type{NTuple{N,V}}) where {N,V} ex = tupexpr(i -> :(z), N) return quote - z = zero(V) + z = VecElement(zero(V)) return $ex end end @@ -205,24 +214,22 @@ end return tupexpr(i -> :(rand(V)), N) end -const NT{N,T} = NTuple{N,T} - # SIMD implementation -@inline add_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) + Vec(b)) -@inline sub_tuples(a::NT{N,T}, b::NT{N,T}) where {N, T<:SIMDType} = Tuple(Vec(a) - Vec(b)) -@inline scale_tuple(tup::NT{N,T}, x::T) where {N, T<:SIMDType} = Tuple(Vec(tup) * x) -@inline div_tuple_by_scalar(tup::NT{N,T}, x::T) where {N, T<:SIMDFloat} = Tuple(Vec(tup) / x) -@inline minus_tuple(tup::NT{N,T}) where {N, T<:SIMDType} = Tuple(-Vec(tup)) -@inline mul_tuples(a::NT{N,T}, b::NT{N,T}, af::T, bf::T) where {N, T<:SIMDType} = Tuple(muladd(af, Vec(a), bf * Vec(b))) +@inline add_tuples(a::NVE{N,T}, b::NVE{N,T}) where {N, T<:SIMDType} = (Vec(a) + Vec(b)).data +@inline sub_tuples(a::NVE{N,T}, b::NVE{N,T}) where {N, T<:SIMDType} = (Vec(a) - Vec(b)).data +@inline scale_tuple(tup::NVE{N,T}, x::T) where {N, T<:SIMDType} = (Vec(tup) * x).data +@inline div_tuple_by_scalar(tup::NVE{N,T}, x::T) where {N, T<:SIMDFloat} = (Vec(tup) / x).data +@inline minus_tuple(tup::NVE{N,T}) where {N, T<:SIMDType} = (-Vec(tup)).data +@inline mul_tuples(a::NVE{N,T}, b::NVE{N,T}, af::T, bf::T) where {N, T<:SIMDType} = (muladd(af, Vec(a), bf * Vec(b))).data # Fallback implementations -@generated add_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(a[$i] + b[$i]), N) -@generated sub_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(a[$i] - b[$i]), N) -@generated scale_tuple(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] * x), N) -@generated div_tuple_by_scalar(tup::NT{N}, x) where N = tupexpr(i -> :(tup[$i] / x), N) -@generated minus_tuple(tup::NT{N}) where N = tupexpr(i -> :(-tup[$i]), N) -@generated mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N = tupexpr(i -> :(muladd(af, a[$i], bf * b[$i])), N) +@generated add_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(VecElement(a[$i].value + b[$i].value)), N) +@generated sub_tuples(a::NT{N}, b::NT{N}) where N = tupexpr(i -> :(VecElement(a[$i].value - b[$i].value).value), N) +@generated scale_tuple(tup::NT{N}, x) where N = tupexpr(i -> :(VecElement(tup[$i].value * x)), N) +@generated div_tuple_by_scalar(tup::NT{N}, x) where N = tupexpr(i -> :(VecElement(tup[$i].value / x)), N) +@generated minus_tuple(tup::NT{N}) where N = tupexpr(i -> :(VecElement(-tup[$i].value)), N) +@generated mul_tuples(a::NT{N}, b::NT{N}, af, bf) where N = tupexpr(i -> :(VecElement(muladd(af, a[$i].value, bf * b[$i].value))), N) ################### # Pretty Printing #