Skip to content

Commit

Permalink
Update docs and clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd authored and maleadt committed Aug 27, 2024
1 parent 0220ee5 commit 05f4752
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
10 changes: 5 additions & 5 deletions docs/src/usage/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ Base's convenience functions for generating random numbers are available in Meta

```jldoctest
julia> Metal.rand(2)
2-element MtlVector{Float32, Private}:
2-element MtlVector{Float32, Metal.PrivateStorage}:
0.89025915
0.8946847
julia> Metal.randn(Float32, 2, 1)
2×1 MtlMatrix{Float32, Private}:
2×1 MtlMatrix{Float32, Metal.PrivateStorage}:
1.2279074
1.2518331
```
Expand All @@ -138,16 +138,16 @@ standard library:
julia> using Random, GPUArrays
julia> a = Random.rand(MPS.default_rng(), Float32, 1)
1-element MtlVector{Float32, Private}:
1-element MtlVector{Float32, Metal.PrivateStorage}:
0.89025915
julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a)
1-element MtlVector{Float32, Private}:
1-element MtlVector{Float32, Metal.PrivateStorage}:
0.0705002
```

!!! note
`MPSMatrixRandom` functionality requires Metal.jl >= v1.3
`MPSMatrixRandom` functionality requires Metal.jl >= v2.0

!!! warning
`Random.rand!(::MPS.RNG, args...)` and `Random.randn!(::MPS.RNG, args...)` have a framework limitation that requires the byte offset and byte size of the destination array to be a multiple of 4.
12 changes: 4 additions & 8 deletions test/random.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
using Random
using Metal
using Metal: can_use_mpsrandom

const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
UInt64]
const RANDN_TYPES = [Float16, Float32]
Expand Down Expand Up @@ -31,7 +27,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
# specified MPS rng
if T != Float16
fill!(A, T(0))
if can_use_mpsrandom(A)
if Metal.can_use_mpsrandom(A)
f(rng, A)
@test !iszero(collect(A))
else
Expand All @@ -51,7 +47,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
# specified MPS rng
if T != Float16
fill!(A, T(0))
if can_use_mpsrandom(A)
if Metal.can_use_mpsrandom(A)
f(rng, A)
@test Array(A) == fill(1, 0)
else
Expand Down Expand Up @@ -139,7 +135,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
view_A = @view A[idx]

# Errors in Julia before crashing whole process
if can_use_mpsrandom(view_A)
if Metal.can_use_mpsrandom(view_A)
f(rng, view_A)

cpuA = collect(A)
Expand All @@ -153,7 +149,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
fill!(A, T(0))
idx = 1:51
view_A = @view A[idx]
if can_use_mpsrandom(view_A)
if Metal.can_use_mpsrandom(view_A)
f(rng, view_A)

cpuA = collect(A)
Expand Down

0 comments on commit 05f4752

Please sign in to comment.