From 018cfdfd4824fcd0fcbf0ce3e7330645d7f36abc Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 6 Nov 2023 15:26:18 +0100 Subject: [PATCH 01/14] add `wrap` function which is the safe counterpart to `unsafe_wrap`. --- NEWS.md | 1 + base/array.jl | 28 ++++++++++++++++++++++++++++ base/exports.jl | 1 + test/arrayops.jl | 14 ++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/NEWS.md b/NEWS.md index 1b3cc6438e03f..1fb0d265d494d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -51,6 +51,7 @@ New library functions * `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]). * `eachrsplit(string, pattern)` iterates split substrings right to left. * `Sys.username()` can be used to return the current user's username ([#51897]). +* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap`. New library features -------------------- diff --git a/base/array.jl b/base/array.jl index 6d59c2f3b41d5..0c79dfecd60ee 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3026,3 +3026,31 @@ intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r) _getindex(v, i) end end + +""" + wrap(Array, m::Union{Memory{T}, MemoryRef{T}}, dims) + +Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version +of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers. +""" +@propagate_inbounds function wrap(::Type{Array}, m::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N} + len = length(m.mem) + @boundscheck len >= prod(dims) || invalid_wrap_err(len, dims) + _wrap(Array, m, convert(Tuple{Vararg{Int}}, dims)) +end +@noinline invalid_wrap_err(len, dims) = throw(DimensionMismatch( + "Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $(prod(dims)) > $len, so that the array would have more elements than the underlying memory can store.")) + +function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N} + wrap(Array, MemoryRef(m), dims) +end +function wrap(::Type{Array}, m::MemoryRef{T}, l::Integer) where {T} + wrap(Array, m, (l,)) +end +function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T} + wrap(Array, MemoryRef(m), (l,)) +end + +@eval @inline function _wrap(::Type{Array}, m::MemoryRef{T}, dims::NTuple{N, Int}) where {T, N} + $(Expr(:new, :(Array{T, N}), :m, :dims)) +end diff --git a/base/exports.jl b/base/exports.jl index b6f7ea0d6ad35..9244b64704890 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -458,6 +458,7 @@ export vcat, vec, view, + wrap, zeros, # search, find, match and related functions diff --git a/test/arrayops.jl b/test/arrayops.jl index 2691da4b17154..c80d231b3c93b 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3135,3 +3135,17 @@ end @test c + zero(c) == c end end + +@testset "Wrapping Memory into Arrays" begin + mem = Memory{Int}(undef, 10) .= 1 + memref = MemoryRef(mem) + @test_throws DimensionMismatch wrap(Array, mem, (10, 10)) + @test wrap(Array, mem, (5,)) == ones(Int, 5) + @test wrap(Array, mem, 2) == ones(Int, 2) + @test wrap(Array, memref, 10) == ones(Int, 10) + + # This is broken because length(a::Array{T, N>1}) is currently doing length(a.ref.mem) !!! + @test_broken wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) + # This works because 5 * 2 happens to equal 10 (the length of mem) + @test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2) +end From e3337392d23d86916e417fde1469a0003845c440 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 6 Nov 2023 22:20:58 +0100 Subject: [PATCH 02/14] update from review --- base/array.jl | 19 +++++++++++-------- test/arrayops.jl | 5 +---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/base/array.jl b/base/array.jl index 0c79dfecd60ee..9c7293af20bb5 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3033,11 +3033,18 @@ end Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers. """ -@propagate_inbounds function wrap(::Type{Array}, m::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N} - len = length(m.mem) - @boundscheck len >= prod(dims) || invalid_wrap_err(len, dims) - _wrap(Array, m, convert(Tuple{Vararg{Int}}, dims)) +@eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N} + mem = ref.mem + mem_len = length(mem) + len = Core.checked_dims(dims...) + @boundscheck mem_len >= len || invalid_wrap_err(men_len, dims) + if N > 1 && len !== mem_len + mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) + ref = MemoryRef(mem) + end + $(Expr(:new, :(Array{T, N}), :ref, :dims)) end + @noinline invalid_wrap_err(len, dims) = throw(DimensionMismatch( "Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $(prod(dims)) > $len, so that the array would have more elements than the underlying memory can store.")) @@ -3050,7 +3057,3 @@ end function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T} wrap(Array, MemoryRef(m), (l,)) end - -@eval @inline function _wrap(::Type{Array}, m::MemoryRef{T}, dims::NTuple{N, Int}) where {T, N} - $(Expr(:new, :(Array{T, N}), :m, :dims)) -end diff --git a/test/arrayops.jl b/test/arrayops.jl index c80d231b3c93b..b8283fe46d89e 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3143,9 +3143,6 @@ end @test wrap(Array, mem, (5,)) == ones(Int, 5) @test wrap(Array, mem, 2) == ones(Int, 2) @test wrap(Array, memref, 10) == ones(Int, 10) - - # This is broken because length(a::Array{T, N>1}) is currently doing length(a.ref.mem) !!! - @test_broken wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) - # This works because 5 * 2 happens to equal 10 (the length of mem) + @test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) @test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2) end From 75c177d868f7ef24efd24ad8834cc7639f392a5d Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Tue, 7 Nov 2023 16:28:14 +0100 Subject: [PATCH 03/14] fix docs --- base/array.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/base/array.jl b/base/array.jl index 9c7293af20bb5..43c3625554f15 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3033,6 +3033,8 @@ end Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers. """ +function wrap end + @eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N} mem = ref.mem mem_len = length(mem) From 4cac28bc2487ebc4f64823f39c884c354a793b38 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 9 Nov 2023 13:22:11 +0100 Subject: [PATCH 04/14] fix typo --- base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/array.jl b/base/array.jl index 43c3625554f15..4da372f15ec45 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3039,7 +3039,7 @@ function wrap end mem = ref.mem mem_len = length(mem) len = Core.checked_dims(dims...) - @boundscheck mem_len >= len || invalid_wrap_err(men_len, dims) + @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims) if N > 1 && len !== mem_len mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) ref = MemoryRef(mem) From 3833df9a7d803e10d2547211fd2b568f547dd743 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:20:53 +0100 Subject: [PATCH 05/14] Update base/array.jl Co-authored-by: Jameson Nash --- base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/array.jl b/base/array.jl index effe072bb2f31..cc0e4449fffb4 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3051,7 +3051,7 @@ function wrap end mem_len = length(mem) len = Core.checked_dims(dims...) @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims) - if N > 1 && len !== mem_len + if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len) mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) ref = MemoryRef(mem) end From 30d387a69615b0ad452bd9c8ff587581ef297142 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:21:22 +0100 Subject: [PATCH 06/14] Update base/array.jl Co-authored-by: Jameson Nash --- base/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/base/array.jl b/base/array.jl index cc0e4449fffb4..e541b9da01a38 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3058,8 +3058,8 @@ function wrap end $(Expr(:new, :(Array{T, N}), :ref, :dims)) end -@noinline invalid_wrap_err(len, dims) = throw(DimensionMismatch( - "Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $(prod(dims)) > $len, so that the array would have more elements than the underlying memory can store.")) +@noinline invalid_wrap_err(len, dims, proddims) = throw(DimensionMismatch( + "Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $proddims > $len, so that the array would have more elements than the underlying memory can store.")) function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N} wrap(Array, MemoryRef(m), dims) From 3e8525369c82d06aec354023b1b52785fb81c1e4 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:21:52 +0100 Subject: [PATCH 07/14] Update array.jl --- base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/array.jl b/base/array.jl index e541b9da01a38..3f39958fe649f 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3050,7 +3050,7 @@ function wrap end mem = ref.mem mem_len = length(mem) len = Core.checked_dims(dims...) - @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims) + @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len) if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len) mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) ref = MemoryRef(mem) From 0e70a6d0c2241a90a772fcb3ef8a3a3d43180e9d Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:25:00 +0100 Subject: [PATCH 08/14] Update array.jl --- base/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/array.jl b/base/array.jl index 3f39958fe649f..50853404ffbcd 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3050,7 +3050,7 @@ function wrap end mem = ref.mem mem_len = length(mem) len = Core.checked_dims(dims...) - @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len) + @boundscheck mem_len + memoryrefoffset(ref) > len || invalid_wrap_err(mem_len, dims, len) if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len) mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) ref = MemoryRef(mem) From 60748b5dd53b756e8cb89805059f791f4ff81eb7 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:28:19 +0100 Subject: [PATCH 09/14] test with an offset memref --- test/arrayops.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/arrayops.jl b/test/arrayops.jl index e38baf5906e9e..3149938543b7a 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3180,4 +3180,10 @@ end @test wrap(Array, memref, 10) == ones(Int, 10) @test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) @test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2) + + memref2 = MemoryRef(mem, 2) + @test wrap(Array, mem, (5,)) == ones(Int, 5) + @test wrap(Array, mem, 2) == ones(Int, 2) + @test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) + @test wrap(Array, mem, (3, 2)) == ones(Int, 3, 2) end From 3937b28660c429d6bf53998afa836b8afb614f4b Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:41:03 +0100 Subject: [PATCH 10/14] Update arrayops.jl --- test/arrayops.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/arrayops.jl b/test/arrayops.jl index 3149938543b7a..28259b50b9c8f 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3182,8 +3182,8 @@ end @test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2) memref2 = MemoryRef(mem, 2) - @test wrap(Array, mem, (5,)) == ones(Int, 5) - @test wrap(Array, mem, 2) == ones(Int, 2) - @test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) - @test wrap(Array, mem, (3, 2)) == ones(Int, 3, 2) + @test wrap(Array, memref2, (5,)) == ones(Int, 5) + @test wrap(Array, memref2, 2) == ones(Int, 2) + @test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2) + @test wrap(Array, memref2, (3, 2)) == ones(Int, 3, 2) end From aafc3816641d510914a0830e20d330fb9198f409 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:55:02 +0100 Subject: [PATCH 11/14] fix memory size checks --- base/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/base/array.jl b/base/array.jl index 50853404ffbcd..0c3372b8d68dc 100644 --- a/base/array.jl +++ b/base/array.jl @@ -3048,9 +3048,9 @@ function wrap end @eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N} mem = ref.mem - mem_len = length(mem) + mem_len = length(mem) + 1 - memoryrefoffset(ref) len = Core.checked_dims(dims...) - @boundscheck mem_len + memoryrefoffset(ref) > len || invalid_wrap_err(mem_len, dims, len) + @boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len) if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len) mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len) ref = MemoryRef(mem) From 8fe2ff16c5ad6cd313c7eba8a34b7117eaeddc84 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 22:55:49 +0100 Subject: [PATCH 12/14] test for `MemoryRef` offsets causing oob --- test/arrayops.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/arrayops.jl b/test/arrayops.jl index 28259b50b9c8f..6844fc1f1a61b 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3186,4 +3186,6 @@ end @test wrap(Array, memref2, 2) == ones(Int, 2) @test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2) @test wrap(Array, memref2, (3, 2)) == ones(Int, 3, 2) + @test_throws DimensionMismatch wrap(Array, memref2, 9) + @test_throws DimensionMismatch wrap(Array, memref2, 10) end From 94d1acdee8779ae27ed022fa63bff9f4f76039cc Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 27 Nov 2023 08:44:20 +0100 Subject: [PATCH 13/14] fix typo --- test/arrayops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/arrayops.jl b/test/arrayops.jl index 6844fc1f1a61b..8e33e209ee88b 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -3181,7 +3181,7 @@ end @test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2) @test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2) - memref2 = MemoryRef(mem, 2) + memref2 = MemoryRef(mem, 3) @test wrap(Array, memref2, (5,)) == ones(Int, 5) @test wrap(Array, memref2, 2) == ones(Int, 2) @test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2) From 3ddbcda3508405ba74fdbc68db7cd15446c14d02 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 27 Nov 2023 14:27:39 +0100 Subject: [PATCH 14/14] Include link to PR --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 6daa9c26c1d07..f813db02819d7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -61,7 +61,7 @@ New library functions * `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]). * `eachrsplit(string, pattern)` iterates split substrings right to left. * `Sys.username()` can be used to return the current user's username ([#51897]). -* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap`. +* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap` ([#52049]). New library features --------------------