diff --git a/docs/src/usage/array.md b/docs/src/usage/array.md index 27c865c12..bd1a9832e 100644 --- a/docs/src/usage/array.md +++ b/docs/src/usage/array.md @@ -119,12 +119,12 @@ Base's convenience functions for generating random numbers are available in Meta ```jldoctest julia> Metal.rand(2) -2-element MtlVector{Float32, Private}: +2-element MtlVector{Float32, Metal.PrivateStorage}: 0.89025915 0.8946847 julia> Metal.randn(Float32, 2, 1) -2×1 MtlMatrix{Float32, Private}: +2×1 MtlMatrix{Float32, Metal.PrivateStorage}: 1.2279074 1.2518331 ``` @@ -138,16 +138,16 @@ standard library: julia> using Random, GPUArrays julia> a = Random.rand(MPS.default_rng(), Float32, 1) -1-element MtlVector{Float32, Private}: +1-element MtlVector{Float32, Metal.PrivateStorage}: 0.89025915 julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a) -1-element MtlVector{Float32, Private}: +1-element MtlVector{Float32, Metal.PrivateStorage}: 0.0705002 ``` !!! note - `MPSMatrixRandom` functionality requires Metal.jl >= v1.3 + `MPSMatrixRandom` functionality requires Metal.jl >= v2.0 !!! warning `Random.rand!(::MPS.RNG, args...)` and `Random.randn!(::MPS.RNG, args...)` have a framework limitation that requires the byte offset and byte size of the destination array to be a multiple of 4. diff --git a/test/random.jl b/test/random.jl index 40aab582a..8889de8ad 100644 --- a/test/random.jl +++ b/test/random.jl @@ -1,7 +1,3 @@ -using Random -using Metal -using Metal: can_use_mpsrandom - const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64] const RANDN_TYPES = [Float16, Float32] @@ -31,7 +27,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES]; # specified MPS rng if T != Float16 fill!(A, T(0)) - if can_use_mpsrandom(A) + if Metal.can_use_mpsrandom(A) f(rng, A) @test !iszero(collect(A)) else @@ -51,7 +47,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES]; # specified MPS rng if T != Float16 fill!(A, T(0)) - if can_use_mpsrandom(A) + if Metal.can_use_mpsrandom(A) f(rng, A) @test Array(A) == fill(1, 0) else @@ -139,7 +135,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES]; view_A = @view A[idx] # Errors in Julia before crashing whole process - if can_use_mpsrandom(view_A) + if Metal.can_use_mpsrandom(view_A) f(rng, view_A) cpuA = collect(A) @@ -153,7 +149,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES]; fill!(A, T(0)) idx = 1:51 view_A = @view A[idx] - if can_use_mpsrandom(view_A) + if Metal.can_use_mpsrandom(view_A) f(rng, view_A) cpuA = collect(A)