Skip to content

Commit

Permalink
Merge branch 'master' into v1
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jul 16, 2024
2 parents 332c70c + 6265c35 commit 61ddbe8
Show file tree
Hide file tree
Showing 19 changed files with 715 additions and 389 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorKit"
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
authors = ["Jutho Haegeman"]
version = "0.12.3"
version = "0.12.5"

[deps]
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
Expand All @@ -17,9 +17,11 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"

[extensions]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitFiniteDifferencesExt = "FiniteDifferences"

[compat]
Aqua = "0.6, 0.7, 0.8"
Expand Down
28 changes: 28 additions & 0 deletions ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Extension module holding the ChainRulesCore differentiation rules for TensorKit.
module TensorKitChainRulesCoreExt

using TensorOperations
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: Backend, promote_contract
using VectorInterface: promote_scale, promote_add

# Locate the ChainRulesCore extension module of TensorOperations: through
# `Base.get_extension` when Base defines it, and as a regular submodule of
# TensorOperations otherwise.
ext = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt)
else
TensorOperations.TensorOperationsChainRulesCoreExt
end
# Reuse the helpers of that extension under local names.
const _conj = ext._conj
const trivtuple = ext.trivtuple

# The rule definitions themselves live in the files included below.
include("utility.jl")
include("constructors.jl")
include("linalg.jl")
include("tensoroperations.jl")
include("factorizations.jl")

end
49 changes: 49 additions & 0 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Declare the following TensorKit constructors as non-differentiable, so that
# ChainRulesCore does not propagate tangents through calls matching these signatures.
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.id(args...)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)

# Reverse rule for constructing a `TensorMap` from the dense array `d`.
# The pullback converts the (unthunked) output cotangent to an `Array` as the
# cotangent of `d`; every remaining positional argument receives `NoTangent()`.
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
    t = TensorMap(d, args...; kwargs...)
    function TensorMap_pullback(Δt)
        Δdense = convert(Array, unthunk(Δt))
        no_tangents = map(_ -> NoTangent(), args)
        return (NoTangent(), Δdense, no_tangents...)
    end
    return t, TensorMap_pullback
end

# Reverse rule for `copy`: the output cotangent is handed back unchanged as the
# cotangent of `t`.
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
    function copy_pullback(Δt)
        return (NoTangent(), Δt)
    end
    return copy(t), copy_pullback
end

# Reverse rule for `convert(T, t)` with `T <: Array`.
# The pullback rebuilds a tensor map on the (co)domain of `t` from the array
# cotangent.
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
                              t::AbstractTensorMap)
    function convert_pullback(ΔA)
        Δarray = unthunk(ΔA)
        # `tol=Inf` makes the constructor project (unconditionally) back onto the
        # symmetric subspace
        ∂t = TensorMap(Δarray, codomain(t), domain(t); tol=Inf)
        return (NoTangent(), NoTangent(), ∂t)
    end
    return convert(T, t), convert_pullback
end

# Reverse rule for `convert(Dict, t)`.
# The pullback only looks at the `:data` entry of the cotangent dictionary: it is
# substituted into a copy of the forward result, which is converted back to a
# `TensorMap`.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
    out = convert(Dict, t)
    function convert_pullback(c′)
        c = unthunk(c′)
        # No `:data` entry: return an explicit zero tensor. `ZeroTangent()` would
        # also be valid here, but makes the pullback type unstable.
        haskey(c, :data) || return (NoTangent(), NoTangent(), zero(t))
        # :data is the only entry for which this dual makes sense
        dual = copy(out)
        dual[:data] = c[:data]
        return (NoTangent(), NoTangent(), convert(TensorMap, dual))
    end
    return out, convert_pullback
end
# Reverse rule for `convert(TensorMap, t)` with `t` a `Dict{Symbol,Any}`.
# The pullback converts the output cotangent back to a `Dict`. The cotangent is
# unthunked first, since a pullback may receive a lazy thunk instead of a
# materialised tensor.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
                              t::Dict{Symbol,Any})
    function convert_pullback(Δt)
        return NoTangent(), NoTangent(), convert(Dict, unthunk(Δt))
    end
    return convert(TensorMap, t), convert_pullback
end
Original file line number Diff line number Diff line change
@@ -1,172 +1,5 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

# Utility
# -------

# Flip a conjugation flag: `:C` becomes `:N`, any other symbol becomes `:C`.
function _conj(conjA::Symbol)
    if conjA === :C
        return :N
    else
        return :C
    end
end
# Return the tuple `(1, 2, ..., N)`.
function trivtuple(N)
    return ntuple(i -> i, N)
end

# Split the indices in `p` into a leading group of `N₁` entries and the remainder.
# Throws an `ArgumentError` when `p` has fewer than `N₁` entries.
function _repartition(p::IndexTuple, N₁::Int)
    if length(p) < N₁
        throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
    end
    leading = p[1:N₁]
    trailing = p[(N₁ + 1):end]
    return leading, trailing
end

# A two-group index tuple is linearized first and then split again.
function _repartition(p::Index2Tuple, N₁::Int)
    return _repartition(linearize(p), N₁)
end

# The size of the leading group is taken from the type parameter of the second argument.
function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
    return _repartition(p, N₁)
end

# The size of the leading group is `numout(t)`.
function _repartition(p::Union{IndexTuple,Index2Tuple}, t::AbstractTensorMap)
    N₁ = numout(t)
    return _repartition(p, N₁)
end

# Asking for a block of a `ZeroTangent` returns that same `ZeroTangent`, for any
# sector.
function TensorKit.block(t::ZeroTangent, c::Sector)
    return t
end

# Constructors
# ------------

# Declare the following TensorKit constructors as non-differentiable, so that
# ChainRulesCore does not propagate tangents through calls matching these signatures.
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)

# Reverse rule for constructing a `TensorMap` from the dense array `d`.
# Keyword arguments are forwarded to the constructor. The pullback returns the
# output cotangent converted to an `Array` as the cotangent of `d`, and
# `NoTangent()` for every remaining positional argument.
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
    function TensorMap_pullback(Δt)
        # the cotangent may arrive as a lazy thunk; materialise it before converting
        ∂d = convert(Array, unthunk(Δt))
        # a tuple (rather than a splatted `fill` vector) keeps the number of
        # returned tangents inferrable
        return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
    end
    return TensorMap(d, args...; kwargs...), TensorMap_pullback
end

# Reverse rule for `convert(T, t)` with `T <: Array`.
# The pullback rebuilds a tensor map on the (co)domain of `t` from the array
# cotangent.
function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap)
    A = convert(T, t)
    function convert_pullback(ΔA)
        # The cotangent may arrive as a lazy thunk, and is in general not symmetric.
        # `tol=Inf` lets the constructor project it (unconditionally) back onto the
        # symmetric subspace instead of rejecting it.
        ∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t); tol=Inf)
        return NoTangent(), NoTangent(), ∂t
    end
    return A, convert_pullback
end

# Reverse rule for `copy`: the output cotangent is handed back unchanged as the
# cotangent of `t`.
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
    function copy_pullback(Δt)
        return (NoTangent(), Δt)
    end
    return copy(t), copy_pullback
end

# `ProjectTo` for tensor maps carries no data: only the type of the primal is
# recorded.
function ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap}
    return ProjectTo{T}()
end

# Project the tensor map `x` onto the primal type `T1`. If the types already agree,
# `x` is returned as is. Otherwise a new tensor with the scalar type of `T1` is
# allocated and filled block by block, using the array projector of each block.
function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N1,N2},
                                         T2<:AbstractTensorMap{<:Any,S,N1,N2}}
    if T1 === T2
        return x
    end
    projected = similar(x, scalartype(T1))
    for (sector, dst) in blocks(projected)
        project_block = ProjectTo(dst)
        dst .= project_block(block(x, sector))
    end
    return projected
end

# Base Linear Algebra
# -------------------

# Reverse rule for `a + b`: both summands receive the output cotangent.
function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap)
    function plus_pullback(Δc)
        return (NoTangent(), Δc, Δc)
    end
    return a + b, plus_pullback
end

# Reverse rule for unary minus: the cotangent of `a` is the negated output cotangent.
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap)
    function negate_pullback(Δc)
        return (NoTangent(), -Δc)
    end
    return -a, negate_pullback
end

# Reverse rule for `a - b`: `a` receives the output cotangent, `b` its negation.
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap)
    function minus_pullback(Δc)
        return (NoTangent(), Δc, -Δc)
    end
    return a - b, minus_pullback
end

# Reverse rule for the product `a * b` of two tensor maps.
# Both cotangents are returned as thunks, so they are only computed on demand.
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap)
    function times_pullback(Δc)
        ∂a = @thunk(Δc * b')
        ∂b = @thunk(a' * Δc)
        return NoTangent(), ∂a, ∂b
    end
    return a * b, times_pullback
end

# Reverse rule for `a * b` with a tensor map `a` and a scalar `b`.
# Both cotangents are returned as thunks, so they are only computed on demand.
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number)
    function times_pullback(Δc)
        ∂a = @thunk(Δc * b')
        ∂b = @thunk(dot(a, Δc))
        return NoTangent(), ∂a, ∂b
    end
    return a * b, times_pullback
end

# Reverse rule for `a * b` with a scalar `a` and a tensor map `b`.
# Both cotangents are returned as thunks, so they are only computed on demand.
function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
    function times_pullback(Δc)
        ∂a = @thunk(dot(b, Δc))
        ∂b = @thunk(a' * Δc)
        return NoTangent(), ∂a, ∂b
    end
    return a * b, times_pullback
end

# Reverse rule for the tensor product `A ⊗ B`.
# The cotangents of `A` and `B` are obtained by contracting the output cotangent
# with the conjugate of the other factor. Both are wrapped in thunks and projected
# back onto the type of the corresponding primal.
function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap)
    C = A ⊗ B
    projectA = ProjectTo(A)
    projectB = ProjectTo(B)
    function otimes_pullback(ΔC_)
        ΔC = unthunk(ΔC_)
        # Partition of the indices of ΔC into two groups: the first holds the
        # codomain indices of A followed by the domain indices of A shifted by
        # numout(B); the second holds the codomain indices of B shifted by numout(A)
        # followed by the domain indices of B shifted by numin(A) + numout(A).
        pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...),
               ((codomainind(B) .+ numout(A))...,
                (domainind(B) .+ (numin(A) + numout(A)))...))
        dA_ = @thunk begin
            ipA = (codomainind(A), domainind(A))
            # all indices of B are contracted
            pB = (allind(B), ())
            dA = zerovector(A,
                            TensorOperations.promote_contract(scalartype(ΔC),
                                                              scalartype(B)))
            dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
            return projectA(dA)
        end
        dB_ = @thunk begin
            ipB = (codomainind(B), domainind(B))
            # all indices of A are contracted
            pA = ((), allind(A))
            dB = zerovector(B,
                            TensorOperations.promote_contract(scalartype(ΔC),
                                                              scalartype(A)))
            dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
            return projectB(dB)
        end
        return NoTangent(), dA_, dB_
    end
    return C, otimes_pullback
end

# Reverse rule for `permute(tsrc, p)`.
# The `copy` keyword is accepted, but both the forward permutation and the one in
# the pullback are always performed with `copy=true`. The pullback applies the
# inverse of the linearized permutation, canonicalized with respect to `tsrc`.
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple;
                              copy::Bool=false)
    tdst = permute(tsrc, p; copy=true)
    function permute_pullback(Δtdst)
        plin = linearize(p)
        invp = TensorKit._canonicalize(TupleTools.invperm(plin), tsrc)
        ∂tsrc = permute(unthunk(Δtdst), invp; copy=true)
        return NoTangent(), ∂tsrc, NoTangent()
    end
    return tdst, permute_pullback
end

# LinearAlgebra
# -------------

# Reverse rule for the trace: the cotangent of `A` is the scalar output cotangent
# times the identity on the domain of `A`.
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
    function tr_pullback(Δtr)
        ∂A = Δtr * id(domain(A))
        return NoTangent(), ∂A
    end
    return tr(A), tr_pullback
end

# Reverse rule for `adjoint`: the cotangent of `A` is the adjoint of the
# (unthunked) output cotangent.
function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
    function adjoint_pullback(Δadjoint)
        ∂A = adjoint(unthunk(Δadjoint))
        return NoTangent(), ∂A
    end
    return adjoint(A), adjoint_pullback
end

# Reverse rule for `dot(a, b)`.
# Both cotangents are returned as thunks, so they are only computed on demand.
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
    function dot_pullback(Δd)
        ∂a = @thunk(b * Δd')
        ∂b = @thunk(a * Δd)
        return NoTangent(), ∂a, ∂b
    end
    return dot(a, b), dot_pullback
end

# Reverse rule for `norm(a, p)`; only `p = 2` is supported, any other value raises
# an error. The pullback returns `a * (Δn' + Δn) / (2n)` as the cotangent of `a`
# and `NoTangent()` for `p`.
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
    if p != 2
        error("currently only implemented for p = 2")
    end
    n = norm(a, p)
    function norm_pullback(Δn)
        ∂a = a * (Δn' + Δn) / (n * 2)
        return NoTangent(), ∂a, NoTangent()
    end
    return n, norm_pullback
end

# Factorizations
# --------------
# Factorizations rules
# --------------------
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(),
p::Real=2,
Expand Down Expand Up @@ -211,6 +44,20 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
return (U′, Σ′, V′, ϵ), tsvd!_pullback
end

# Reverse rule for `svdvals!`.
# The forward pass runs `tsvd` and returns the diagonal of `S`. The pullback puts
# the cotangent of the singular values on the diagonal of a tensor shaped like `S`,
# sandwiches it between `U` and `V`, and projects the result onto the type of `t`.
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
    project_t = ProjectTo(t)
    U, S, V = tsvd(t)

    function svdvals_pullback(Δs′)
        ΔS = diagm(codomain(S), domain(S), unthunk(Δs′))
        ∂t = project_t(U * ΔS * V)
        return NoTangent(), ∂t
    end

    return diag(S), svdvals_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
D, V = eig(t; kwargs...)

Expand Down Expand Up @@ -253,6 +100,21 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k
return (D, V), eigh!_pullback
end

# Reverse rule for `eigvals!`, built on top of the rule for `TensorKit.eig!`.
# Only `sortby=nothing` is supported; any other value throws an `ArgumentError`.
# The remaining keyword arguments are forwarded to the `eig!` rule. The pullback
# puts the cotangent of the eigenvalues on the diagonal of a tensor shaped like
# `D`, feeds it to the `eig!` pullback together with a zero cotangent for the
# eigenvectors, and projects the result onto the type of `t`.
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
                              sortby=nothing, kwargs...)
    # explicit check instead of `@assert`, which is not guaranteed to run
    sortby === nothing ||
        throw(ArgumentError("only `sortby=nothing` is supported"))
    (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
    d = diag(D)
    project_t = ProjectTo(t)
    function eigvals_pullback(Δd′)
        Δd = unthunk(Δd′)
        ΔD = diagm(codomain(D), domain(D), Δd)
        # entry 2 of the `eig!` pullback result is the cotangent of `t`
        return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
    end

    return d, eigvals_pullback
end

function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
error("only `alg=QR()` and `alg=QRpos()` are supported")
Expand Down Expand Up @@ -526,9 +388,8 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4))
Rd = view(R, diagind(R))
p = let tol = atol > 0 ? atol : rtol * maximum(abs, Rd)
findlast(x -> abs(x) >= tol, Rd)
end
tol = atol > 0 ? atol : rtol * maximum(abs, Rd)
p = findlast(>=(tol) abs, Rd)
m, n = size(R)

Q1 = view(Q, :, 1:p)
Expand All @@ -538,7 +399,6 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
ΔA1 = view(ΔA, :, 1:p)
ΔQ1 = view(ΔQ, :, 1:p)
ΔR1 = view(ΔR, 1:p, :)
ΔR11 = view(ΔR, 1:p, 1:p)

M = similar(R, (p, p))
ΔR isa AbstractZero || mul!(M, ΔR1, R1')
Expand Down Expand Up @@ -581,9 +441,8 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4))
Ld = view(L, diagind(L))
p = let tol = atol > 0 ? atol : rtol * maximum(abs, Ld)
findlast(x -> abs(x) >= tol, Ld)
end
tol = atol > 0 ? atol : rtol * maximum(abs, Ld)
p = findlast(>=(tol) abs, Ld)
m, n = size(L)

L1 = view(L, :, 1:p)
Expand All @@ -593,7 +452,6 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
ΔA1 = view(ΔA, 1:p, :)
ΔQ1 = view(ΔQ, 1:p, :)
ΔL1 = view(ΔL, :, 1:p)
ΔR11 = view(ΔL, 1:p, 1:p)

M = similar(L, (p, p))
ΔL isa AbstractZero || mul!(M, L1', ΔL1)
Expand Down Expand Up @@ -631,26 +489,3 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
ldiv!(LowerTriangular(L11)', ΔA1)
return ΔA
end

# Convert rrules
#----------------
# Reverse rule for `convert(Dict, t)`.
# The pullback only looks at the `:data` entry of the cotangent dictionary: it is
# substituted into a copy of the forward result, which is converted back to a
# `TensorMap`.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
    out = convert(Dict, t)
    function convert_pullback(c′)
        # the cotangent may arrive as a lazy thunk; materialise it before inspecting it
        c = unthunk(c′)
        if haskey(c, :data) # :data is the only thing for which this dual makes sense
            dual = copy(out)
            dual[:data] = c[:data]
            return (NoTangent(), NoTangent(), convert(TensorMap, dual))
        else
            # instead of zero(t) you can also return ZeroTangent(), which is type unstable
            return (NoTangent(), NoTangent(), zero(t))
        end
    end
    return out, convert_pullback
end
# Reverse rule for `convert(TensorMap, t)` with `t` a `Dict{Symbol,Any}`.
# The pullback converts the output cotangent back to a `Dict`. The cotangent is
# unthunked first, since a pullback may receive a lazy thunk instead of a
# materialised tensor.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
                              t::Dict{Symbol,Any})
    function convert_pullback(Δt)
        return NoTangent(), NoTangent(), convert(Dict, unthunk(Δt))
    end
    return convert(TensorMap, t), convert_pullback
end

end
Loading

0 comments on commit 61ddbe8

Please sign in to comment.