From ac48189bdab1fd31af527ba2c907b0e085950bcc Mon Sep 17 00:00:00 2001 From: Stephan Hilb Date: Sat, 27 Oct 2018 20:21:26 +0200 Subject: [PATCH] base: make diff() use views and broadcasting --- base/multidimensional.jl | 22 ++++++++++------------ test/arrayops.jl | 6 ++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index d79715299fffd1..839b723585fc83 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -659,10 +659,7 @@ end end end -function diff(a::AbstractVector) - @assert !has_offset_axes(a) - [ a[i+1] - a[i] for i=1:length(a)-1 ] -end +diff(a::AbstractVector) = diff(a, dims=1) """ diff(A::AbstractVector) @@ -690,14 +687,15 @@ julia> diff(vec(a)) 12 ``` """ -function diff(A::AbstractMatrix; dims::Integer) - if dims == 1 - [A[i+1,j] - A[i,j] for i=1:size(A,1)-1, j=1:size(A,2)] - elseif dims == 2 - [A[i,j+1] - A[i,j] for i=1:size(A,1), j=1:size(A,2)-1] - else - throw(ArgumentError("dimension must be 1 or 2, got $dims")) - end +function diff(a::AbstractArray{T,N}; dims::Integer) where {T,N} + has_offset_axes(a) && throw(ArgumentError("offset axes unsupported")) + 1 <= dims <= N || throw(ArgumentError("dimension $dims out of range (1:$N)")) + + r = axes(a) + r0 = ntuple(i -> i == dims ? UnitRange(1, last(r[i]) - 1) : UnitRange(r[i]), N) + r1 = ntuple(i -> i == dims ? UnitRange(2, last(r[i])) : UnitRange(r[i]), N) + + return view(a, r1...) .- view(a, r0...) end ### from abstractarray.jl diff --git a/test/arrayops.jl b/test/arrayops.jl index 83f75dd97d068e..15c647da4da7df 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2283,6 +2283,9 @@ end @testset "diff" begin # test diff, throw ArgumentError for invalid dimension argument + v = [7, 3, 5, 1, 9] + @test diff(v) == [-4, 2, -4, 8] + @test diff(v,dims=1) == [-4, 2, -4, 8] X = [3 9 5; 7 4 2; 2 1 10] @@ -2292,6 +2295,9 @@ end @test diff(view(X, 1:2, 1:2),dims=2) == reshape([6; -3], (2,1)) @test diff(view(X, 2:3, 2:3),dims=1) == [-3 8] @test diff(view(X, 2:3, 2:3),dims=2) == reshape([-2; 9], (2,1)) + Y = cat([1 3; 4 3], [6 5; 1 4], dims=3) + @test diff(Y, dims=3) == reshape([5 2; -3 1], (2, 2, 1)) + @test_throws UndefKeywordError diff(X) @test_throws ArgumentError diff(X,dims=3) @test_throws ArgumentError diff(X,dims=-1) end