Skip to content

Commit

Permalink
Added ArrayKernels
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Jun 7, 2019
1 parent 59000ef commit 053a171
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
54 changes: 54 additions & 0 deletions src/Kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ using Test
using CellwiseValues

export NumberKernel
export ArrayKernel
export NumberKernelFromFunction
export ArrayKernelFromBroadcastedFunction
export compute_type
export compute_value
export compute_value!
export compute_size
export test_number_kernel
export test_array_kernel

# Interfaces

Expand All @@ -21,6 +26,20 @@ function compute_value(::NumberKernel,::Vararg)::NumberLike
@abstractmethod
end

abstract type ArrayKernel end

function compute_type(::ArrayKernel,::Vararg{<:Type})::Type{<:NumberLike}
@abstractmethod
end

function compute_size(::ArrayKernel,::Vararg{<:NTuple})::NTuple
@abstractmethod
end

function compute_value!(::AbstractArray,::NumberKernel,::Vararg)
@abstractmethod
end

# Testers

function test_number_kernel(k::NumberKernel,o::T,i::Vararg) where T
Expand All @@ -31,6 +50,23 @@ function test_number_kernel(k::NumberKernel,o::T,i::Vararg) where T
@test r == o
end

function test_array_kernel(
k::ArrayKernel,o::AbstractArray{T,N},i::Vararg) where {T,N}
t = [ eltype(ii) for ii in i ]
S = compute_type(k,t...)
@test S == T
s = [ _size_for_broadcast(ii) for ii in i ]
si = compute_size(k,s...)
@test size(o) == si
r = Array{T,N}(undef,si)
compute_value!(r,k,i...)
@test r == o
end

_size_for_broadcast(a) = size(a)

_size_for_broadcast(a::NumberLike) = ()

# Implementations

struct NumberKernelFromFunction{F<:Function} <: NumberKernel
Expand All @@ -48,4 +84,22 @@ function compute_value(self::NumberKernelFromFunction,args::Vararg)
self.fun(args...)
end

struct ArrayKernelFromBroadcastedFunction{F<:Function} <: ArrayKernel
fun::F
end

function compute_type(
self::ArrayKernelFromBroadcastedFunction,etypes::Vararg{<:Type})
Base._return_type(self.fun,etypes)
end

function compute_size(::ArrayKernelFromBroadcastedFunction,s::Vararg{<:NTuple})
Base.Broadcast.broadcast_shape(s...)
end

function compute_value!(
v::AbstractArray, k::ArrayKernelFromBroadcastedFunction, a::Vararg)
broadcast!(k.fun,v,a...)
end

end # module Kernels
31 changes: 30 additions & 1 deletion test/KernelsTests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module KernelsTests

using CellwiseValues
using StaticArrays
using TensorValues

k = NumberKernelFromFunction(-)
Expand All @@ -14,4 +13,34 @@ k = NumberKernelFromFunction(sum)

test_number_kernel(k,7,[1,2,4])

k = ArrayKernelFromBroadcastedFunction(*)

test_array_kernel(k,[3,4,3],[1,2,3],[3,2,1])

test_array_kernel(k,[6,4,2],2,[3,2,1])

test_array_kernel(k,[6,4,2],[3,2,1],2)

test_array_kernel(k,[6,4,2],[3,2,1],2,1)

u = rand(Int,2,3,1)
v = rand(Int,1,3,4)
w = u .* v
test_array_kernel(k,w,u,v)
test_array_kernel(k,w,u,v,1)

k = ArrayKernelFromBroadcastedFunction(+)

v1 = VectorValue(2,3)
v2 = VectorValue(3,2)
v3 = VectorValue(1,2)
u = [v1,v2,v3]
v = v1
w = u .+ v
test_array_kernel(k,w,u,v)

w = broadcast(+,u,v,0)
test_array_kernel(k,w,u,v,0)


end # module KernelsTests

0 comments on commit 053a171

Please sign in to comment.