diff --git a/Project.toml b/Project.toml index 6a6fa5ee..6cb7bbc2 100644 --- a/Project.toml +++ b/Project.toml @@ -17,9 +17,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" @@ -33,7 +33,6 @@ AMDGPU = "0.9.4" Adapt = "3.2, 4" Atomix = "0.1" CUDA = "4, 5" -cuDNN = "1" ChainRulesCore = "1.13" EnzymeCore = "0.5, 0.6, 0.7" FFTW = "1.8.0" @@ -44,4 +43,5 @@ Pkg = "<0.0.1, 1" Random = "<0.0.1, 1" Requires = "1.0" Statistics = "1" +cuDNN = "1" julia = "1.9" diff --git a/src/padding.jl b/src/padding.jl index 903212d6..76aa1ddd 100644 --- a/src/padding.jl +++ b/src/padding.jl @@ -270,8 +270,12 @@ function pad_reflect( ) where {F,N} lpad, rpad = pad n = size(x, dims) - xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 2:lpad+1); dims) - xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad:n-1); dims) + xl = lpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, 2:lpad+1); dims) + xr = rpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, n-rpad:n-1); dims) return cat(xl, x, xr; dims) end @@ -326,8 +330,12 @@ function pad_symmetric( lpad, rpad = pad n = size(x, dims) - xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 1:lpad); dims) - xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad+1:n); dims) + xl = lpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, 1:lpad); dims) + xr = rpad == 0 ? + similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) : + reverse(selectdim(x, dims, n-rpad+1:n); dims) return cat(xl, x, xr; dims) end diff --git a/test/padding.jl b/test/padding.jl index f4d7b11e..a066d054 100644 --- a/test/padding.jl +++ b/test/padding.jl @@ -1,68 +1,68 @@ using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular @testset "padding constant" begin - x = rand(2, 2, 2) - + x = rand(2, 2, 2) + p = NNlib.gen_pad((1,2,3,4,5,6), (1,2,3), 4) @test p == ((1, 2), (3, 4), (5, 6), (0, 0)) - + @test_throws ArgumentError NNlib.gen_pad((1,2,3,4,5,), (1,2,3), 4) - + p = NNlib.gen_pad((1,3), (1,3), 4) @test p == ((1, 1), (0, 0), (3, 3), (0, 0)) - + p = NNlib.gen_pad(1, (1,2,3), 4) @test p == ((1, 1), (1, 1), (1, 1), (0, 0)) - + p = NNlib.gen_pad(3, :, 2) @test p == ((3, 3), (3, 3)) p = NNlib.gen_pad((1,0), 1, 2) @test p == ((1,0), (0,0)) - + y = pad_constant(x, (3, 2, 4)) @test size(y) == (8, 6, 10) @test y[4:5, 3:4, 5:6] ≈ x y[4:5, 3:4, 5:6] .= 0 @test all(y .== 0) - + @test pad_constant(x, (3, 2, 4)) ≈ pad_zeros(x, (3, 2, 4)) - @test pad_zeros(x, 2) ≈ pad_zeros(x, (2,2,2)) - + @test pad_zeros(x, 2) ≈ pad_zeros(x, (2,2,2)) + y = pad_constant(x, (3, 2, 4, 5), 1.2, dims = (1,3)) @test size(y) == (7, 2, 11) @test y[4:5, 1:2, 5:6] ≈ x y[4:5, 1:2, 5:6] .= 1.2 @test all(y .== 1.2) - + @test pad_constant(x, (2,2,2,2), 1.2, dims = (1,3)) ≈ pad_constant(x, 2, 1.2, dims = (1,3)) - + @test pad_constant(x, 1, dims = 1:2) == - pad_constant(x, 1, dims = (1,2)) - + pad_constant(x, 1, dims = (1,2)) + @test size(pad_constant(x, 1, dims = 1)) == (4,2,2) - + @test all(pad_zeros(randn(2), (1, 2))[[1, 4, 5]] .== 0) - + gradtest(x -> pad_constant(x, 2), rand(2,2,2)) gradtest(x -> pad_constant(x, (2, 1, 1, 2)), rand(2,2)) gradtest(x -> pad_constant(x, (2, 1,)), rand(2)) end @testset "padding repeat" begin - x = rand(2, 2, 2) - + x = rand(2, 2, 2) + # y = @inferred pad_repeat(x, (3, 2, 4, 5)) y = pad_repeat(x, (3, 2, 4, 5)) @test size(y) == (7, 11, 2) @test y[4:5, 5:6, :] ≈ x - + # y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3)) y = pad_repeat(x, (3, 2, 4, 5), dims=(1,3)) @test size(y) == (7, 2, 11) @test y[4:5, :, 5:6] ≈ x - + @test pad_repeat(reshape(1:9, 3, 3), (1,2)) == [1 4 7 1 4 7 @@ -70,15 +70,15 @@ end 3 6 9 3 6 9 3 6 9] - + @test pad_repeat(reshape(1:9, 3, 3), (2,2), dims=2) == [1 1 1 4 7 7 7 2 2 2 5 8 8 8 3 3 3 6 9 9 9] - + @test pad_repeat(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_repeat(x, 2, dims=(1,3)) - + gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2)) end @@ -87,7 +87,7 @@ end @test y == [7 4 1 4 7 4 1 8 5 2 5 8 5 2 9 6 3 6 9 6 3] - + y = pad_reflect(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [9 6 3 6 9 6 3 8 5 2 5 8 5 2 @@ -96,14 +96,26 @@ end 9 6 3 6 9 6 3 8 5 2 5 8 5 2 7 4 1 4 7 4 1] - - x = rand(4, 4, 4) + + x = rand(4, 4, 4) @test pad_reflect(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_reflect(x, 2, dims=(1,3)) - - # pad_reflect needs larger test input as padding must + + # pad_reflect needs larger test input as padding must # be strictly less than array size in that dimension gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3)) + + x = reshape(1:9, 3, 3, 1, 1) + @test NNlib.pad_reflect(x, (1, 0, 1, 0); dims=1:2) == [ + 5 2 5 8; + 4 1 4 7; + 5 2 5 8; + 6 3 6 9;;;;] + @test NNlib.pad_reflect(x, (0, 1, 0, 1); dims=1:2) == [ + 1 4 7 4; + 2 5 8 5; + 3 6 9 6; + 2 5 8 5;;;;] end @testset "padding symmetric" begin @@ -111,7 +123,7 @@ end @test y == [4 1 1 4 7 7 4 5 2 2 5 8 8 5 6 3 3 6 9 9 6] - + y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [5 2 2 5 8 8 5 4 1 1 4 7 7 4 @@ -120,12 +132,24 @@ end 6 3 3 6 9 9 6 6 3 3 6 9 9 6 5 2 2 5 8 8 5] - - x = rand(4, 4, 4) + + x = rand(4, 4, 4) @test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_symmetric(x, 2, dims=(1,3)) - + gradtest(x -> pad_symmetric(x, (2,2,2,2)), rand(2,2,2)) + + x = reshape(1:9, 3, 3, 1, 1) + @test NNlib.pad_symmetric(x, (1, 0, 1, 0); dims=1:2) == [ + 1 1 4 7; + 1 1 4 7; + 2 2 5 8; + 3 3 6 9;;;;] + @test NNlib.pad_symmetric(x, (0, 1, 0, 1); dims=1:2) == [ + 1 4 7 7; + 2 5 8 8; + 3 6 9 9; + 3 6 9 9;;;;] end @testset "padding circular" begin @@ -133,7 +157,7 @@ end @test y == [4 7 1 4 7 1 4 5 8 2 5 8 2 5 6 9 3 6 9 3 6] - + y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [5 8 2 5 8 2 5 6 9 3 6 9 3 6 @@ -142,10 +166,10 @@ end 6 9 3 6 9 3 6 4 7 1 4 7 1 4 5 8 2 5 8 2 5] - - x = rand(4, 4, 4) + + x = rand(4, 4, 4) @test pad_circular(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_circular(x, 2, dims=(1,3)) - + gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2)) end