From 1dde9783c406fb06947a1e47f7da0caa9fc25794 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Tue, 27 Aug 2024 17:28:44 -0300 Subject: [PATCH] Add wrappers for `MPSMatrixRandom` (#321) Co-authored-by: Tim Besard --- docs/src/usage/array.md | 45 +++++++ lib/mps/MPS.jl | 2 + lib/mps/matrixrandom.jl | 145 ++++++++++++++++++++ lib/mps/random.jl | 109 +++++++++++++++ src/random.jl | 61 +++++++-- test/mps/copy.jl | 2 +- test/random.jl | 288 ++++++++++++++++++++++++++++++++++------ 7 files changed, 606 insertions(+), 46 deletions(-) create mode 100644 lib/mps/matrixrandom.jl create mode 100644 lib/mps/random.jl diff --git a/docs/src/usage/array.md b/docs/src/usage/array.md index 42e27db7b..0a121df75 100644 --- a/docs/src/usage/array.md +++ b/docs/src/usage/array.md @@ -3,6 +3,12 @@ ```@meta DocTestSetup = quote using Metal + using GPUArrays + + import Random + Random.seed!(1) + + Metal.seed!(1) end ``` @@ -106,3 +112,42 @@ julia> Base.mapreducedim!(identity, +, b, a) 1×1 MtlMatrix{Float32, Metal.PrivateStorage}: 6.0 ``` + +## Random numbers + +Base's convenience functions for generating random numbers are available in Metal as well: + +```jldoctest +julia> Metal.rand(2) +2-element MtlVector{Float32, Metal.PrivateStorage}: + 0.89025915 + 0.8946847 + +julia> Metal.randn(Float32, 2, 1) +2×1 MtlMatrix{Float32, Metal.PrivateStorage}: + 1.2279074 + 1.2518331 +``` + +Behind the scenes, these random numbers come from two different generators: one backed by +[Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixrandom?language=objc), +another by using the GPUArrays.jl random methods. Operations on these generators are implemented using methods from the Random +standard library: + +```jldoctest +julia> using Random, GPUArrays + +julia> a = Random.rand(MPS.default_rng(), Float32, 1) +1-element MtlVector{Float32, Metal.PrivateStorage}: + 0.89025915 + +julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a) +1-element MtlVector{Float32, Metal.PrivateStorage}: + 0.0705002 +``` + +!!! note + `MPSMatrixRandom` functionality requires Metal.jl >= v1.4 + +!!! warning + `Random.rand!(::MPS.RNG, args...)` and `Random.randn!(::MPS.RNG, args...)` have a framework limitation that requires the byte offset and byte size of the destination array to be a multiple of 4. diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index 2bb794187..7266eae9a 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -28,10 +28,12 @@ include("kernel.jl") include("images.jl") include("matrix.jl") include("vector.jl") +include("matrixrandom.jl") include("decomposition.jl") include("copy.jl") # integrations +include("random.jl") include("linalg.jl") end diff --git a/lib/mps/matrixrandom.jl b/lib/mps/matrixrandom.jl new file mode 100644 index 000000000..f166159de --- /dev/null +++ b/lib/mps/matrixrandom.jl @@ -0,0 +1,145 @@ +@cenum MPSMatrixRandomDistribution::UInt begin + MPSMatrixRandomDistributionDefault = 1 + MPSMatrixRandomDistributionUniform = 2 + MPSMatrixRandomDistributionNormal = 3 +end + +# +# matrix random descriptor +# + +export MPSMatrixRandomDistributionDescriptor + +@objcwrapper immutable=false MPSMatrixRandomDistributionDescriptor <: NSObject + +@objcproperties MPSMatrixRandomDistributionDescriptor begin + @autoproperty distributionType::MPSMatrixRandomDistribution + @autoproperty maximum::Float32 setter=setMaximum + @autoproperty mean::Float32 setter=setMean + @autoproperty minimum::Float32 setter=setMimimum + @autoproperty standardDeviation::Float32 setter=setStandardDeviation +end + + +function MPSMatrixRandomDefaultDistributionDescriptor() + desc = @objc [MPSMatrixRandomDistributionDescriptor defaultDistributionDescriptor]::id{MPSMatrixRandomDistributionDescriptor} + obj = MPSMatrixRandomDistributionDescriptor(desc) + return obj +end + +# Default constructor +MPSMatrixRandomDistributionDescriptor() = MPSMatrixRandomDefaultDistributionDescriptor() + +function MPSMatrixRandomNormalDistributionDescriptor(mean, standardDeviation) + desc = @objc [MPSMatrixRandomDistributionDescriptor normalDistributionDescriptorWithMean:mean::Float32 + standardDeviation:standardDeviation::Float32]::id{MPSMatrixRandomDistributionDescriptor} + obj = MPSMatrixRandomDistributionDescriptor(desc) + return obj +end + +function MPSMatrixRandomNormalDistributionDescriptor(mean, standardDeviation, minimum, maximum) + desc = @objc [MPSMatrixRandomDistributionDescriptor normalDistributionDescriptorWithMean:mean::Float32 + standardDeviation:standardDeviation::Float32 + minimum:minimum::Float32 + maximum:maximum::Float32]::id{MPSMatrixRandomDistributionDescriptor} + obj = MPSMatrixRandomDistributionDescriptor(desc) + return obj +end + +function MPSMatrixRandomUniformDistributionDescriptor(minimum, maximum) + desc = @objc [MPSMatrixRandomDistributionDescriptor uniformDistributionDescriptorWithMinimum:minimum::Float32 + maximum:maximum::Float32]::id{MPSMatrixRandomDistributionDescriptor} + obj = MPSMatrixRandomDistributionDescriptor(desc) + return obj +end + + +@objcwrapper immutable=false MPSMatrixRandom <: MPSKernel + +@objcproperties MPSMatrixRandom begin + @autoproperty batchSize::NSUInteger + @autoproperty batchStart::NSUInteger + @autoproperty destinationDataType::id{MPSDataType} + @autoproperty distributionType::id{MPSMatrixRandomDistributionDescriptor} +end + +function encode!(cmdbuf::MTLCommandBuffer, kernel::K, destinationMatrix::MPSMatrix) where {K<:MPSMatrixRandom} + @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + destinationMatrix:destinationMatrix::id{MPSMatrix}]::Nothing +end +function encode!(cmdbuf::MTLCommandBuffer, kernel::K, destinationVector::MPSVector) where {K<:MPSMatrixRandom} + @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer} + destinationVector:destinationVector::id{MPSVector}]::Nothing +end + +@objcwrapper immutable=false MPSMatrixRandomMTGP32 <: MPSMatrixRandom +@objcwrapper immutable=false MPSMatrixRandomPhilox <: MPSMatrixRandom + +for R in [:MPSMatrixRandomMTGP32, :MPSMatrixRandomPhilox] + @eval begin + function $R(device) + kernel = @objc [$R alloc]::id{$R} + obj = $R(kernel) + finalizer(release, obj) + @objc [obj::id{$R} initWithDevice:device::id{MTLDevice}]::id{$R} + return obj + end + function $R(device, destinationDataType, seed) + kernel = @objc [$R alloc]::id{$R} + obj = $R(kernel) + finalizer(release, obj) + @objc [obj::id{$R} initWithDevice:device::id{MTLDevice} + destinationDataType:destinationDataType::MPSDataType + seed:seed::NSUInteger]::id{$R} + return obj + end + function $R(device, destinationDataType, seed, distributionDescriptor) + kernel = @objc [$R alloc]::id{$R} + obj = $R(kernel) + finalizer(release, obj) + @objc [obj::id{$R} initWithDevice:device::id{MTLDevice} + destinationDataType:destinationDataType::MPSDataType + seed:seed::NSUInteger + distributionDescriptor:distributionDescriptor::id{MPSMatrixRandomDistributionDescriptor}]::id{$R} + return obj + end + end +end + +synchronize_state(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandBuffer) = + @objc [obj::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing + + +@inline function _mpsmat_rand!(randkern::MPSMatrixRandom, dest::MtlArray{T}, ::Type{T2}; + queue::MTLCommandQueue = global_queue(device()), + async::Bool=false) where {T,T2} + byteoffset = dest.offset * sizeof(T) + bytesize = sizeof(dest) + + # Even though `append_copy`` seems to work with any size or offset values, the documentation at + # https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc + # mentions that both must be multiples of 4 bytes in MacOS so error when they are not + (bytesize % 4 == 0) || error(lazy"Destination buffer bytesize ($(bytesize)) must be a multiple of 4.") + (byteoffset % 4 == 0) || error(lazy"Destination buffer offset ($(byteoffset)) must be a multiple of 4.") + + cmdbuf = if bytesize % 16 == 0 && dest.offset == 0 + MTLCommandBuffer(queue) do cmdbuf + vecDesc = MPSVectorDescriptor(bytesize ÷ sizeof(T2), T2) + mpsdest = MPSVector(dest, vecDesc) + encode!(cmdbuf, randkern, mpsdest) + end + else + MTLCommandBuffer(queue) do cmdbuf + len = UInt(ceil(bytesize / sizeof(T2)) * 4) + vecDesc = MPSVectorDescriptor(len, T2) + tempVec = MPSTemporaryVector(cmdbuf, vecDesc) + encode!(cmdbuf, randkern, tempVec) + MTLBlitCommandEncoder(cmdbuf) do enc + MTL.append_copy!(enc, dest.data[], byteoffset, tempVec.data, tempVec.offset, bytesize) + end + end + end + + async || wait_completed(cmdbuf) + return +end diff --git a/lib/mps/random.jl b/lib/mps/random.jl new file mode 100644 index 000000000..81ce58585 --- /dev/null +++ b/lib/mps/random.jl @@ -0,0 +1,109 @@ +using Random +using Metal: DefaultStorageMode + +""" + MPS.RNG() + +A random number generator using `rand()` in a device kernel. +""" +mutable struct RNG <: AbstractRNG + device::MTLDevice + uniformInteger::MPSMatrixRandomPhilox + uniformFloat32::MPSMatrixRandomPhilox + normalFloat32::MPSMatrixRandomPhilox +end + + +make_seed() = Base.rand(RandomDevice(), UInt) + +function RNG(device::MTLDevice, seed::Integer) + seed = seed%UInt + RNG(device, + MPSMatrixRandomPhilox(device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()), + MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)), + MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)),) +end +@autoreleasepool RNG(seed::Integer) = RNG(device(), seed) +RNG(device::MTLDevice) = RNG(device, make_seed()) + +@autoreleasepool RNG() = RNG(device(), make_seed()) + +Base.copy(rng::RNG) = RNG(copy(rng.device), copy(rng.uniformInteger), copy(rng.uniformFloat32), copy(rng.normalFloat32)) + +@autoreleasepool function Random.seed!(rng::RNG, seed::Integer) + rng.uniformInteger = MPSMatrixRandomPhilox(rng.device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()) + rng.uniformFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)) + rng.normalFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)) + return rng +end + +Random.seed!(rng::RNG) = Random.seed!(rng, make_seed()) + +const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}() +@autoreleasepool function default_rng() + dev = device() + get!(GLOBAL_RNGs, dev) do + RNG(dev) + end +end + +const UniformTypes = [Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64] +const UniformType = Union{[Type{T} for T in UniformTypes]...} +const UniformArray = MtlArray{<:Union{Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} +@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} + isempty(A) && return A + _mpsmat_rand!(rng.uniformInteger, A, UInt32) + return A +end + +@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{Float32}) + isempty(A) && return A + _mpsmat_rand!(rng.uniformFloat32, A, Float32) + return A +end + +const NormalType = Type{Float32} +const NormalArray = MtlArray{<:Float32} +@autoreleasepool function Random.randn!(rng::RNG, A::MtlArray{Float32}) + isempty(A) && return A + _mpsmat_rand!(rng.normalFloat32, A, Float32) + return A +end + +# CPU arrays +function Random.rand!(rng::RNG, A::AbstractArray{T,N}) where {T <: Union{UniformTypes...}, N} + isempty(A) && return A + B = MtlArray{T,N,SharedStorage}(undef, size(A)) + rand!(rng, B) + copyto!(A, unsafe_wrap(Array{T},B)) + return A +end +function Random.randn!(rng::RNG, A::AbstractArray{T,N}) where {T <: Float32, N} + isempty(A) && return A + B = MtlArray{T,N,SharedStorage}(undef, size(A)) + randn!(rng, B) + copyto!(A, unsafe_wrap(Array{T},B)) + return A +end + +# Out of place +Random.rand(rng::RNG, T::UniformType, dims::Dims; storage=DefaultStorageMode) = + Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) +Random.randn(rng::RNG, T::NormalType, dims::Dims; storage=DefaultStorageMode) = + Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) + +# support all dimension specifications +Random.rand(rng::RNG, T::UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) +Random.randn(rng::RNG, T::NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) + +# untyped out-of-place +Random.rand(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.rand!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) +Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) + +# scalars +Random.rand(rng::RNG, T::UniformType=Float32; storage=SharedStorage) = rand(rng, T, 4; storage)[1] +Random.randn(rng::RNG, T::NormalType=Float32; storage=SharedStorage) = randn(rng, T, 4; storage)[1] diff --git a/src/random.jl b/src/random.jl index 81cc48c00..acb456766 100644 --- a/src/random.jl +++ b/src/random.jl @@ -1,24 +1,69 @@ using Random +using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor, + MPSMatrixRandomNormalDistributionDescriptor gpuarrays_rng() = GPUArrays.default_rng(MtlArray) +mpsrand_rng() = MPS.default_rng() # GPUArrays in-place Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A) Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A) +@inline function can_use_mpsrandom(A::MtlArray{T}) where {T} + return A.offset * sizeof(T) % 4 == 0 && sizeof(A) % 4 == 0 +end + +# Use MPS random functionality where possible +function Random.rand!(A::MPS.UniformArray) + rng = can_use_mpsrandom(A) ? mpsrand_rng() : gpuarrays_rng() + return Random.rand!(rng, A) +end +function Random.randn!(A::MPS.NormalArray) + rng = can_use_mpsrandom(A) ? mpsrand_rng() : gpuarrays_rng() + return Random.randn!(rng, A) +end + # GPUArrays out-of-place -rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...)) -randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...) +function rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode) + rng = prod(dims) * sizeof(T) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng() + return Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) +end +function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode) + rng = prod(dims) * sizeof(T) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng() + return Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) +end +rand(T::Type, dims::Dims; storage=DefaultStorageMode) = + Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) +randn(T::Type, dims::Dims; storage=DefaultStorageMode) = + Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) # support all dimension specifications +function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) + rng = (dim1 * prod(dims) * sizeof(T)) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng() + return Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) +end +function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) + rng = (dim1 * prod(dims) * sizeof(T)) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng() + return Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) +end + rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = - Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...)) -randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = - Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) + Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) +randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) # untyped out-of-place -rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...)) -randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) +rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) +randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) + +# scalars +rand(T::Type=Float32; storage=SharedStorage) = rand(T, 4; storage)[1] +randn(T::Type=Float32; storage=SharedStorage) = randn(T, 4; storage)[1] # seeding -seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed) +function seed!(seed=Base.rand(UInt64)) + Random.seed!(gpuarrays_rng(), seed) + Random.seed!(mpsrand_rng(), seed) +end diff --git a/test/mps/copy.jl b/test/mps/copy.jl index ac20f533e..3c3f2ea15 100644 --- a/test/mps/copy.jl +++ b/test/mps/copy.jl @@ -33,7 +33,7 @@ end Ts = Ts[.!(Ts .<: IGNORE_UNION)] @testset "$T" for T in Ts for dim in ((16,16), (10,500), (500,10), (256,512)) - srcMat = Metal.rand(T, dim) + srcMat = MtlArray(rand(T, dim)) dstMat = copytest(srcMat, false, false) @test dstMat == srcMat diff --git a/test/random.jl b/test/random.jl index 89c771bca..608f03b08 100644 --- a/test/random.jl +++ b/test/random.jl @@ -1,39 +1,253 @@ -using Random - -@testset "rand" begin - -# in-place -for (f,T) in ((rand!,Float16), - (rand!,Float32), - (randn!,Float16), - (randn!,Float32)), - d in (2, (2,2), (2,2,2), 3, (3,3), (3,3,3)) - A = MtlArray{T}(undef, d) - fill!(A, T(0)) - f(A) - @test !iszero(collect(A)) -end - -# out-of-place, with implicit type -for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32)), - args in ((2,), (2, 2), (3,), (3, 3)) - A = f(args...) - @test eltype(A) == T -end - -# out-of-place, with type specified -for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32), - (rand,Float32), (randn,Float32)), - args in ((T, 2), (T, 2, 2), (T, (2, 2)), (T, 3), (T, 3, 3), (T, (3, 3))) - A = f(args...) - @test eltype(A) == T -end - -## seeding -Metal.seed!(1) -a = Metal.rand(Int32, 1) -Metal.seed!(1) -b = Metal.rand(Int32, 1) -@test iszero(collect(a) - collect(b)) +const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, + UInt64] +const RANDN_TYPES = [Float16, Float32] +const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES]; + [(randn!, T) for T in RANDN_TYPES]] +const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES]; + [(Metal.randn, rand, T) for T in RANDN_TYPES]] +@testset "random" begin + # in-place + @testset "in-place" begin + rng = Metal.MPS.RNG() + + @testset "$f with $T" for (f, T) in INPLACE_TUPLES + @testset "$d" for d in (1, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000)) + A = MtlArray{T}(undef, d) + + # default_rng + fill!(A, T(0)) + f(A) + @test !iszero(collect(A)) + + # specified MPS rng + if T != Float16 + fill!(A, T(0)) + if Metal.can_use_mpsrandom(A) + f(rng, A) + @test !iszero(collect(A)) + else + @test_throws "Destination buffer" f(rng, A) + end + end + end + + @testset "0" begin + A = MtlArray{T}(undef, 0) + + # default_rng + f(A) + @test A isa MtlArray{T,1} + @test Array(A) == fill(1, 0) + + # specified MPS rng + if T != Float16 + fill!(A, T(0)) + if Metal.can_use_mpsrandom(A) + f(rng, A) + @test Array(A) == fill(1, 0) + else + @test_throws "Destination buffer" f(rng, A) + end + end + end + end + end + + # in-place contiguous views + @testset "in-place for views" begin + @testset "$f with $T" for (f, T) in INPLACE_TUPLES + alen = 100 + A = MtlArray{T}(undef, alen) + function test_view!(X::MtlArray{T}, idx) where {T} + fill!(X, T(0)) + view_X = @view X[idx] + f(view_X) + cpuX = collect(X) + not_zero_in_view = !iszero(cpuX[idx]) + rest_of_array_untouched = iszero(cpuX[1:alen .∉ Ref(idx)]) + return not_zero_in_view, rest_of_array_untouched + end + + # Test when view offset is 0 and buffer size not multiple of 4 + @testset "Off == 0, buf % 4 != 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 1:51) + @test not_zero_in_view + @test rest_of_array_untouched + end + + # Test when view offset is 0 and buffer size is multiple of 16 + @testset "Off == 0, buf % 16 == 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 1:32) + @test not_zero_in_view + @test rest_of_array_untouched + end + + # Test when view offset is 0 and buffer size is multiple of 4 + @testset "Off == 0, buf % 4 == 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 1:36) + @test not_zero_in_view + @test rest_of_array_untouched + end + + # Test when view offset is not 0 nor multiple of 4 and buffer size not multiple of 16 + @testset "Off != 0, buf % 4 != 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 3:51) + @test not_zero_in_view + @test rest_of_array_untouched + end + + # Test when view offset is multiple of 4 and buffer size not multiple of 4 + @testset "Off % 4 == 0, buf % 4 != 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 17:51) + @test not_zero_in_view + @test rest_of_array_untouched + end + + # Test when view offset is multiple of 4 and buffer size multiple of 16 + @testset "Off % 4 == 0, buf % 16 == 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 9:40) + @test not_zero_in_view + @test rest_of_array_untouched + end + + # Test when view offset is multiple of 4 and buffer size multiple of 4 + @testset "Off % 16 == 0, buf % 4 == 0" begin + not_zero_in_view, rest_of_array_untouched = test_view!(A, 9:32) + @test not_zero_in_view + @test rest_of_array_untouched + end + end + + # Test when views try to use rand!(rng, args..) + @testset "MPS.RNG with views" begin + rng = Metal.MPS.RNG() + @testset "$f with $T" for (f, T) in ((randn!, Float32),(rand!, Int64),(rand!, Float32), (rand!, UInt16), (rand!,Int8)) + A = MtlArray{T}(undef, 100) + + ## Offset > 0 + fill!(A, T(0)) + idx = 4:50 + view_A = @view A[idx] + + # Errors in Julia before crashing whole process + if Metal.can_use_mpsrandom(view_A) + f(rng, view_A) + + cpuA = collect(A) + @test !iszero(cpuA[idx]) + @test iszero(cpuA[1:100 .∉ Ref(idx)]) broken=(sizeof(view_A) % 4 != 0) + else + @test_throws "Destination buffer" f(rng, view_A) + end + + ## Offset == 0 + fill!(A, T(0)) + idx = 1:51 + view_A = @view A[idx] + if Metal.can_use_mpsrandom(view_A) + f(rng, view_A) + + cpuA = collect(A) + @test !iszero(cpuA[idx]) + @test iszero(cpuA[1:100 .∉ Ref(idx)]) + else + @test_throws "Destination buffer" f(rng, view_A) + end + end + end + end + # out-of-place + @testset "out-of-place" begin + @testset "$fr with implicit type" for (fm, fr, T) in + ((Metal.rand, rand, Float32), (Metal.randn, rand, Float32)) + rng = Metal.MPS.RNG() + @testset "args" for args in ((0,), (1,), (3,), (3, 3), (16,), (16, 16), (1000,), (1000,1000)) + # default_rng + A = fm(args...) + @test eltype(A) == T + + # specified MPS rng + B = fr(rng, args...) + @test eltype(B) == T + end + + @testset "scalar" begin + a = fm() + @test typeof(a) == T + b = fr(rng) + @test typeof(b) == T + end + end + + # out-of-place, with type specified + @testset "$fr with $T" for (fm, fr, T) in OOPLACE_TUPLES + rng = Metal.MPS.RNG() + @testset "$args" for args in ((T, 0), + (T, 1), + (T, 3), + (T, 3, 3), + (T, (3, 3)), + (T, 16), + (T, 16, 16), + (T, (16, 16)), + (T, 1000), + (T, 1000, 1000),) + # default_rng + A = fm(args...) + @test eltype(A) == T + + # specified MPS rng + if T != Float16 + if length(zeros(args...)) * sizeof(T) % 4 == 0 + B = fr(rng, args...) + @test eltype(B) == T + else + @test_throws "Destination buffer" fr(rng, args...) + end + end + end + + @testset "scalar" begin + a = fm(T) + @test typeof(a) == T + b = fr(rng, T) + @test typeof(b) == T + end + end + end + + ## CPU Arrays with MPS rng + @testset "CPU Arrays" begin + mps_tuples = filter(INPLACE_TUPLES) do tup + tup[2] != Float16 + end + rng = Metal.MPS.RNG() + @testset "$f with $T" for (f, T) in mps_tuples + @testset "$d" for d in (1, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000)) + A = zeros(T, d) + if (prod(d) * sizeof(T)) % 4 == 0 + f(rng, A) + @test !iszero(collect(A)) + else + @test_throws "Destination buffer" f(rng, A) + end + end + end + end + + ## seeding + @testset "Seeding $L" for (f,T,L) in [(Metal.rand,UInt32,"Uniform Integers MPS"), + (Metal.rand,Float32,"Uniform Float32 MPS"), + (Metal.randn,Float32,"Normal Float32 MPS"), + (Metal.randn,Float16,"Float16 GPUArrays")] + @testset "$d" for d in (1, 3, (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000)) + Metal.seed!(1) + a = f(T, d) + Metal.seed!(1) + b = f(T, d) + # TODO: Remove broken parameter once https://github.com/JuliaGPU/GPUArrays.jl/issues/530 is fixed + @test Array(a) == Array(b) broken = (T == Float16 && d == (1000,1000)) + end + end end # testset