Skip to content

Commit

Permalink
fix adjoint test more
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Apr 20, 2022
1 parent d0dc91b commit a436697
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
20 changes: 16 additions & 4 deletions src/cueinsum.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
using .CUDA

# GPU array types accepted by the CUDA `einsum` methods in this file:
# `DenseCuArray`s plus `Transpose`/`Adjoint` wrappers around a `CuArray`.
const CUDAArrayTypes{T,N} = Union{
    LinearAlgebra.Transpose{T,<:CuArray{T,N}},
    DenseCuArray{T,N},
    LinearAlgebra.Adjoint{T,<:CuArray{T,N}}
}

# `_unwrap` turns a wrapped input into a plain `CuArray` through the `CuArray`
# constructor; an input that already is a `CuArray` is returned as is.
function _unwrap(x::LinearAlgebra.Adjoint{T,<:CuArray{T}}) where T
    return CuArray(x)
end
function _unwrap(x::LinearAlgebra.Transpose{T,<:CuArray{T}}) where T
    return CuArray(x)
end
function _unwrap(x::CuArray)
    return x
end

# Wrap the value `x` in a zero-dimensional `CuArray`. The second argument is
# not read; it only selects this method through dispatch.
asarray(x, arr::CuArray) = CuArray(fill(x, ()))
# An input that is already an array is returned unchanged.
asarray(x::AbstractArray, y::CuArray) = x
# Convert `x` to an `Array` and read its element with empty indexing (`[]`).
function asscalar(x::DenseCuArray)
    host = Array(x)
    return host[]
end
Expand Down Expand Up @@ -81,10 +86,6 @@ end

# Report zero dimensions for a `Broadcasted` object whose style is
# `CUDA.CuArrayStyle{0}`.
function Base.ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}})
    return 0
end

# GPU array types accepted by the CUDA `einsum` methods in this file:
# `DenseCuArray`s plus `Transpose`/`Adjoint` wrappers around a `CuArray`.
const CUDAArrayTypes{T,N} = Union{
    LinearAlgebra.Transpose{T,<:CuArray{T,N}},
    DenseCuArray{T,N},
    LinearAlgebra.Adjoint{T,<:CuArray{T,N}}
}

# `_unwrap` turns a wrapped input into a plain `CuArray` through the `CuArray`
# constructor; an input that already is a `CuArray` is returned as is.
function _unwrap(x::LinearAlgebra.Adjoint{T,<:CuArray{T}}) where T
    return CuArray(x)
end
function _unwrap(x::LinearAlgebra.Transpose{T,<:CuArray{T}}) where T
    return CuArray(x)
end
function _unwrap(x::CuArray)
    return x
end
function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict; active_free=false)
# do not use map because the static overhead is too large
# do not use `setindex!` because we need to make the AD work
Expand All @@ -99,4 +100,15 @@ function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes
return res
end

# Entry point for statically coded contractions on GPU inputs. It exists so that
# `Adjoint`/`Transpose`-wrapped arrays dispatch correctly: every input is passed
# through `_unwrap` before the rule-specific `einsum` method is called.
@generated function einsum(code::StaticEinCode{ixs, iy}, xs::NTuple{N,CUDAArrayTypes} where N, size_dict::Dict{LT}) where {LT, ixs, iy}
    # `ixs` and `iy` are type parameters, so the rule is matched while the body
    # is generated and then spliced into the returned expression.
    matched = match_rule(ixs, iy)
    return quote
        einsum($matched, $ixs, $iy, _unwrap.(xs), size_dict)
    end
end

# Entry point for dynamically coded contractions on GPU inputs: look up the
# contraction rule from the code's labels, pass every input through `_unwrap`,
# and forward to the rule-specific `einsum` method.
function einsum(code::DynamicEinCode, @nospecialize(xs::NTuple{N,CUDAArrayTypes} where N), size_dict::Dict)
    ixs = getixs(code)
    iy = getiy(code)
    rule = match_rule(ixs, iy)
    return einsum(rule, ixs, iy, _unwrap.(xs), size_dict)
end

# Log a message once this file has been evaluated.
@info "OMEinsum loaded the CUDA module successfully"
4 changes: 3 additions & 1 deletion test/cueinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ end
end

@testset "adjoint dispatch" begin
    # `A` is 2×2×2 because the `mk, ijk -> ijm` contractions below share index
    # `k` between `u'` (2×2) and the third axis of `A`.
    u = CUDA.rand(2,2); A = CUDA.rand(2,2,2);
    # Each GPU result (with an adjoint input) is copied to an `Array` and
    # compared approximately against the same contraction on `Array` inputs.
    @test Array(ein"(ip,pql),qj -> ijl"(u', A, u)) ≈ ein"(ip,pql),qj -> ijl"(Array(CuArray(u')), Array(A), Array(u))
    @test Array(DynamicEinCode(ein"mk, ijk -> ijm")(u', A)) ≈ DynamicEinCode(ein"mk, ijk -> ijm")(Array(u'), Array(A))
    @test Array(ein"mk, ijk -> ijm"(u', A)) ≈ DynamicEinCode(ein"mk, ijk -> ijm")(Array(u'), Array(A))
end

0 comments on commit a436697

Please sign in to comment.