From 053a171124e2c64b1457aae8c18abcc9cd93494b Mon Sep 17 00:00:00 2001 From: fverdugo Date: Fri, 7 Jun 2019 08:32:45 +0200 Subject: [PATCH] Added ArrayKernels --- src/Kernels.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++ test/KernelsTests.jl | 31 ++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/src/Kernels.jl b/src/Kernels.jl index 69e4cc1..acda220 100644 --- a/src/Kernels.jl +++ b/src/Kernels.jl @@ -4,10 +4,15 @@ using Test using CellwiseValues export NumberKernel +export ArrayKernel export NumberKernelFromFunction +export ArrayKernelFromBroadcastedFunction export compute_type export compute_value +export compute_value! +export compute_size export test_number_kernel +export test_array_kernel # Interfaces @@ -21,6 +26,20 @@ function compute_value(::NumberKernel,::Vararg)::NumberLike @abstractmethod end +abstract type ArrayKernel end + +function compute_type(::ArrayKernel,::Vararg{<:Type})::Type{<:NumberLike} + @abstractmethod +end + +function compute_size(::ArrayKernel,::Vararg{<:NTuple})::NTuple + @abstractmethod +end + +function compute_value!(::AbstractArray,::NumberKernel,::Vararg) + @abstractmethod +end + # Testers function test_number_kernel(k::NumberKernel,o::T,i::Vararg) where T @@ -31,6 +50,23 @@ function test_number_kernel(k::NumberKernel,o::T,i::Vararg) where T @test r == o end +function test_array_kernel( + k::ArrayKernel,o::AbstractArray{T,N},i::Vararg) where {T,N} + t = [ eltype(ii) for ii in i ] + S = compute_type(k,t...) + @test S == T + s = [ _size_for_broadcast(ii) for ii in i ] + si = compute_size(k,s...) + @test size(o) == si + r = Array{T,N}(undef,si) + compute_value!(r,k,i...) + @test r == o +end + +_size_for_broadcast(a) = size(a) + +_size_for_broadcast(a::NumberLike) = () + # Implementations struct NumberKernelFromFunction{F<:Function} <: NumberKernel @@ -48,4 +84,22 @@ function compute_value(self::NumberKernelFromFunction,args::Vararg) self.fun(args...) end +struct ArrayKernelFromBroadcastedFunction{F<:Function} <: ArrayKernel + fun::F +end + +function compute_type( + self::ArrayKernelFromBroadcastedFunction,etypes::Vararg{<:Type}) + Base._return_type(self.fun,etypes) +end + +function compute_size(::ArrayKernelFromBroadcastedFunction,s::Vararg{<:NTuple}) + Base.Broadcast.broadcast_shape(s...) +end + +function compute_value!( + v::AbstractArray, k::ArrayKernelFromBroadcastedFunction, a::Vararg) + broadcast!(k.fun,v,a...) +end + end # module Kernels diff --git a/test/KernelsTests.jl b/test/KernelsTests.jl index 60197de..8bcac5b 100644 --- a/test/KernelsTests.jl +++ b/test/KernelsTests.jl @@ -1,7 +1,6 @@ module KernelsTests using CellwiseValues -using StaticArrays using TensorValues k = NumberKernelFromFunction(-) @@ -14,4 +13,34 @@ k = NumberKernelFromFunction(sum) test_number_kernel(k,7,[1,2,4]) +k = ArrayKernelFromBroadcastedFunction(*) + +test_array_kernel(k,[3,4,3],[1,2,3],[3,2,1]) + +test_array_kernel(k,[6,4,2],2,[3,2,1]) + +test_array_kernel(k,[6,4,2],[3,2,1],2) + +test_array_kernel(k,[6,4,2],[3,2,1],2,1) + +u = rand(Int,2,3,1) +v = rand(Int,1,3,4) +w = u .* v +test_array_kernel(k,w,u,v) +test_array_kernel(k,w,u,v,1) + +k = ArrayKernelFromBroadcastedFunction(+) + +v1 = VectorValue(2,3) +v2 = VectorValue(3,2) +v3 = VectorValue(1,2) +u = [v1,v2,v3] +v = v1 +w = u .+ v +test_array_kernel(k,w,u,v) + +w = broadcast(+,u,v,0) +test_array_kernel(k,w,u,v,0) + + end # module KernelsTests