Skip to content

Commit

Permalink
Specialize Base.reinterpret (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
green-nsk authored Jul 2, 2022
1 parent 654d42a commit 204fcd9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/unsafe_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ end

Base.deepcopy(A::UnsafeArray) = copyto!(similar(A), A)

function Base.reinterpret(::Type{DST}, A::UnsafeArray{SRC}) where {DST, SRC}
if sizeof(DST) != sizeof(SRC)
sz = size(A)
sz1, rem = divrem(sz[1] * sizeof(SRC), sizeof(DST))
@boundscheck if rem != zero(rem)
throw(ArgumentError("Resulting array would have non-integral first dimension"))
end
UnsafeArray(convert(Ptr{DST}, pointer(A)), (sz1, Base.tail(sz)...))
else
UnsafeArray(convert(Ptr{DST}, pointer(A)), size(A))
end
end

# # Defining Base.unaliascopy results in very bad broadcast performance for
# # some reason, even when it shouldn't be called. By default, unaliascopy
Expand Down
17 changes: 17 additions & 0 deletions test/unsafe_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,21 @@ using Random
@test convert(Array, UA) == A
end
end

@testset "reinterpret" begin
test_A_UA(Float32, Val(3)) do A, UA
@test reinterpret(UInt32, UA) == reinterpret(UInt32, A)
@test reinterpret(UInt8, UA) == reinterpret(UInt8, A)
end

# NOTE: this requires sz_max[1] to be divisible by 4
test_A_UA(UInt8, Val(3)) do A, UA
@test reinterpret(UInt32, UA) == reinterpret(UInt32, A)
@test reinterpret(Int8, UA) == reinterpret(Int8, A)
end

A = UInt32[ 1, 2, 3 ]
UA = UnsafeArray(pointer(A), size(A))
@test_throws ArgumentError reinterpret(UInt64, UA)
end
end

0 comments on commit 204fcd9

Please sign in to comment.