Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a functor for projection #385

Merged
merged 44 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
914bd92
Sketch project implementation
willtebbutt Feb 24, 2021
06678a4
change Composite to Tangent
Jun 22, 2021
c58f974
export project
Jun 22, 2021
00020e3
make T optional
Jun 22, 2021
37f9253
add tests and Complex
Jun 22, 2021
4e1b79d
workout the edge cases
Jun 22, 2021
7dc58ee
rename dummy struct
Jun 22, 2021
3345ba9
rename project to projector
Jun 23, 2021
31d81ed
move to projector
Jun 24, 2021
2ea4845
do not close over x (other than in the general case)
Jun 24, 2021
465e1d7
update docstring
Jun 24, 2021
0a06dce
fix getproperty
Jun 24, 2021
d822b02
add to Tangent and to Symmetric
Jun 24, 2021
25a7cee
remove debug strings
Jun 24, 2021
7801e19
separate out the projector
Jun 24, 2021
9147fad
implement preproject
Jun 25, 2021
cc2f199
remove getproperty for thunks
Jun 25, 2021
2aa3859
remove to Tangent
Jun 25, 2021
44ef266
fix docstrings
Jun 25, 2021
d8848f5
project nested structs
Jun 25, 2021
88da9c6
Change from preproject to ProjectTo functor
oxinabox Jun 29, 2021
e0318b3
Make sure Arrays of Arrays etc work
oxinabox Jun 29, 2021
ce5d646
remove the special case ProjectTo(::Type{<:Number})
Jun 30, 2021
12a0db4
Merge branch 'master' into mz/projectto
Jun 30, 2021
f1a6260
add to_ prefix, add Transpose/Adjoint/SubArray
Jun 30, 2021
06268a3
add Adjoint and Transpose test
Jun 30, 2021
a981279
test Tangents with implicit zeros
Jun 30, 2021
eefd84f
throw error when ProjectTo to Tuple or NamedTuple
Jun 30, 2021
2facaea
fix transpose bug
Jun 30, 2021
9787b1b
add test for TwoFields
Jul 1, 2021
93c7489
test complex numbers too
Jul 1, 2021
e7190b2
nested where
Jul 1, 2021
233d292
fix SubArray
Jul 1, 2021
4c25f32
add Hermitian
Jul 2, 2021
029cb69
remove debug statements
Jul 2, 2021
b73e246
add Upper and LowerTriangular
Jul 2, 2021
9d665c0
PermutedDimsArray
Jul 2, 2021
030d636
Update test/projection.jl
mzgubic Jul 2, 2021
b87368f
fix docs
Jul 5, 2021
0f09ab9
JuliaFormatter
Jul 5, 2021
3a47f6f
simplify one of the PermutedDimsArray
Jul 5, 2021
ce022d5
document when to use ProjectTo
Jul 5, 2021
4106232
Apply suggestions from code review
mzgubic Jul 6, 2021
04a4e87
Update docs/Manifest.toml
mzgubic Jul 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export canonicalize, extern, unthunk # differential operations
export ProjectTo, canonicalize, extern, unthunk # differential operations
export add!! # gradient accumulation operations
# differentials
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
Expand All @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl")

include("differential_arithmetic.jl")
include("accumulation.jl")
include("projection.jl")

include("config.jl")
include("rules.jl")
Expand Down
1 change: 1 addition & 0 deletions src/differentials/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ backing(x::NamedTuple) = x
# A `Dict` already is its own backing collection, so it is handed back unchanged.
function backing(x::Dict)
    return x
end
# Read the `backing` field with `getfield`, i.e. without going through `getproperty`.
function backing(x::Tangent)
    return getfield(x, :backing)
end

# For generic structs
function backing(x::T)::NamedTuple where T
# note: all computation outside the if @generated happens at runtime.
# so the first 4 lines of the branches look the same, but cannot be moved out.
Expand Down
132 changes: 132 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Callable projector object. `P` is the type being projected onto; `info` is a NamedTuple
# holding whatever extra data the projection needs.
struct ProjectTo{P,D<:NamedTuple}
    info::D
end

# Infer the NamedTuple parameter `D` from the supplied `info`.
function ProjectTo{P}(info::D) where {P,D<:NamedTuple}
    return ProjectTo{P,D}(info)
end

# Keyword form: the keyword arguments become the fields of `info`.
function ProjectTo{P}(; kwargs...) where {P}
    return ProjectTo{P}(NamedTuple(kwargs))
end

# The stored NamedTuple is fetched with `getfield`, because `getproperty` is overloaded
# below to forward to that NamedTuple.
function backing(project::ProjectTo)
    return getfield(project, :info)
end

# `p.name` looks `name` up in the stored NamedTuple rather than in the struct itself.
function Base.getproperty(p::ProjectTo, name::Symbol)
    return getproperty(backing(p), name)
end

# The visible properties are the keys of the stored NamedTuple.
function Base.propertynames(p::ProjectTo)
    return propertynames(backing(p))
end

# Print as `ProjectTo{T}` followed by the stored NamedTuple, or by `()` when nothing is
# stored.
function Base.show(io::IO, project::ProjectTo{T}) where {T}
    info = backing(project)
    print(io, "ProjectTo{")
    show(io, T)
    print(io, "}")
    isempty(info) ? print(io, "()") : show(io, info)
end


"""
    ProjectTo(x)

Return a `ProjectTo{P,...}` functor able to project a differential `dx` onto the type `P`
for a primal `x`.
This functor closes over whatever is needed to be able to do that projection.
For example, when projecting `dx=ZeroTangent()` onto an array `P=Array{T, N}`, the size of
`x` is not available from `P`, so it is stored in the functor.
"""
function ProjectTo end

"""
    (::ProjectTo{T})(dx)

Project the differential `dx` onto the type `T`.
`ProjectTo{T}` is a functor that knows how to perform this projection.
"""
function (::ProjectTo) end
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

# fallback (structs)
# Generic fallback for structs: build a projector for every field of `x` and store them,
# keyed by field name, in a projector for the struct type `T`.
function ProjectTo(x::T) where {T}
    # `backing` is asserted to give a NamedTuple of the fields of `x`.
    fields_nt::NamedTuple = backing(x)
    # Recursively make `ProjectTo`s for all the fields.
    return ProjectTo{T}(map(ProjectTo, fields_nt))
end
# Project a structural `Tangent` onto the struct type `T`: apply each field's projector to
# the matching field of the tangent, then rebuild a `T` from the results.
function (project::ProjectTo{T})(dx::Tangent) where {T}
    sub_projects = backing(project)
    # NOTE(review): `canonicalize` presumably fills in the fields left implicit in `dx` so
    # that both NamedTuples line up field by field; confirm against its definition.
    sub_dxs = backing(canonicalize(dx))
    _call(f, x) = f(x)
    return construct(T, map(_call, sub_projects, sub_dxs))
end

# should not work for Tuples and NamedTuples, as not valid tangent types
# Tuples and NamedTuples are rejected up front with an `ArgumentError`.
function ProjectTo(x::T) where {T<:Union{<:Tuple, NamedTuple}}
    msg = "The `x` in `ProjectTo(x)` must be a valid differential, not $x"
    throw(ArgumentError(msg))
end

# Generic
# Thunks are unthunked first, then projected like any other differential.
function (project::ProjectTo)(dx::AbstractThunk)
    return project(unthunk(dx))
end

# A differential that already has the target type is returned as is.
# This is not always true, but we can special case for when it isn't.
function (::ProjectTo{T})(dx::T) where {T}
    return dx
end
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
# Generic zero differential: project to `zero(T)`.
function (::ProjectTo{T})(dx::AbstractZero) where {T}
    return zero(T)
end

# Number
# A Number needs no extra information beyond its type.
function ProjectTo(::T) where {T<:Number}
    return ProjectTo{T}()
end

# Projecting onto a general Number type is a plain conversion.
function (::ProjectTo{T})(dx::Number) where {T<:Number}
    return convert(T, dx)
end

# Projecting onto a Real type drops any imaginary part before converting.
function (::ProjectTo{T})(dx::Number) where {T<:Real}
    return convert(T, real(dx))
end

# Arrays
# General arrays keep one projector per element, stored in an array of the same shape.
function ProjectTo(xs::T) where {T<:Array}
    element_projectors = map(ProjectTo, xs)
    return ProjectTo{T}(; elements=element_projectors)
end
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
# Array differential: apply each stored element projector to the matching element of `dx`.
function (project::ProjectTo{T})(dx::Array) where {T<:Array}
    return T(map((proj, d) -> proj(d), project.elements, dx))
end

# Zero differential: every element projector receives the same zero.
function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array}
    projected = map(project.elements) do proj
        return proj(dx)
    end
    return T(projected)
end

# Any other AbstractArray is collected before being projected.
function (project::ProjectTo{<:Array})(dx::AbstractArray)
    return project(collect(dx))
end

# Arrays{<:Number}: optimized case so we don't need a projector per element
# Optimized case for arrays of numbers: a single element projector plus the array size,
# instead of one projector per element.
function ProjectTo(x::T) where {E<:Number,T<:Array{E}}
    return ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x))
end

# Broadcast the shared element projector over the differential.
function (project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number}
    return project.element.(dx)
end

# A zero differential becomes an array of zeros of the stored size.
function (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number}
    return zeros(T, project.size)
end

# A structural tangent of a SubArray is projected through its `parent` field.
function (project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number}
    return project(dx.parent)
end

# Diagonal
# A Diagonal is projected through a projector for its diagonal vector.
function ProjectTo(x::T) where {T<:Diagonal}
    return ProjectTo{T}(; diag=ProjectTo(diag(x)))
end

# Matrix differential: keep only the diagonal, project it, and rewrap.
function (project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal}
    projected_diag = project.diag(diag(dx))
    return T(projected_diag)
end

# Zero differential: the diagonal projector produces the zero diagonal.
function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal}
    return T(project.diag(dx))
end

# Symmetric and Hermitian
# Symmetric and Hermitian share one implementation, generated once per wrapper type.
for SymHerm in (:Symmetric, :Hermitian)
    @eval begin
        # Store the `uplo` flag (as a Symbol) and a projector for the parent matrix.
        function ProjectTo(x::T) where {T<:$SymHerm}
            return ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x)))
        end

        # Matrix differential: project as the parent, then rewrap with the stored `uplo`.
        function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix)
            return $SymHerm(project.parent(dx), project.uplo)
        end

        # Zero differential: the parent projector supplies the zero matrix.
        function (project::ProjectTo{<:$SymHerm})(dx::AbstractZero)
            return $SymHerm(project.parent(dx), project.uplo)
        end

        # Structural tangent: the matrix is read from its `data` field.
        function (project::ProjectTo{<:$SymHerm})(dx::Tangent)
            return $SymHerm(project.parent(dx.data), project.uplo)
        end
    end
end

# Transpose
# A Transpose is projected through a projector for its parent array.
function ProjectTo(x::T) where {T<:Transpose}
    return ProjectTo{T}(; parent=ProjectTo(parent(x)))
end

# Matrix differential: transpose into the parent's layout, project, transpose back.
function (project::ProjectTo{<:Transpose})(dx::AbstractMatrix)
    parent_dx = transpose(dx)
    return transpose(project.parent(parent_dx))
end

# Adjoint differential: conjugate its parent instead of transposing the wrapper.
function (project::ProjectTo{<:Transpose})(dx::Adjoint)
    parent_dx = conj(parent(dx))
    return transpose(project.parent(parent_dx))
end

# Zero differential: the parent projector supplies the zero, which is then transposed.
function (project::ProjectTo{<:Transpose})(dx::AbstractZero)
    return transpose(project.parent(dx))
end

# Adjoint
# An Adjoint is projected through a projector for its parent array.
function ProjectTo(x::T) where {T<:Adjoint}
    return ProjectTo{T}(; parent=ProjectTo(parent(x)))
end

# Matrix differential: take the adjoint into the parent's layout, project, and take the
# adjoint again.
function (project::ProjectTo{<:Adjoint})(dx::AbstractMatrix)
    return adjoint(project.parent(adjoint(dx)))
end

# Accept every `AbstractZero`, not only `ZeroTangent`, as the Transpose, Diagonal,
# Symmetric and Array projectors in this file do. Otherwise other zero tangents fall
# through to the generic `zero(T)` method with `T` an Adjoint wrapper type.
function (project::ProjectTo{<:Adjoint})(dx::AbstractZero)
    return adjoint(project.parent(dx))
end

# SubArray
# Don't project onto a view, but onto a matching copy.
function ProjectTo(x::T) where {T<:SubArray}
    return ProjectTo(copy(x))
end
mzgubic marked this conversation as resolved.
Show resolved Hide resolved












215 changes: 215 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Test fixtures for the projection tests.

# Minimal struct with one numeric field; `zero` is defined for both instance and type.
struct Fred
    a::Float64
end
Base.zero(::Fred) = Fred(0.0)
Base.zero(::Type{Fred}) = Fred(0.0)

# Struct with an array field. Arrays compare by identity under the default `==` for
# structs, so equality is defined on the contents instead.
struct Freddy{T,N}
    a::Array{T,N}
end
Base.:(==)(a::Freddy, b::Freddy) = a.a == b.a
# Keep `hash` consistent with the custom `==` so equal values hash equally.
Base.hash(a::Freddy, h::UInt) = hash(a.a, hash(:Freddy, h))

# Struct nesting another struct.
struct Mary
    a::Fred
end

# Struct with two fields, used for tangents that set only one of them.
struct TwoFields
    a::Float64
    b::Float64
end

@testset "projection" begin
@testset "display" begin
# A struct projector prints its type parameter followed by the stored info.
@test startswith(repr(ProjectTo(Fred(1.1))), "ProjectTo{Fred}(")
# A Number projector stores nothing, so it prints with empty parentheses.
@test repr(ProjectTo(1.1)) == "ProjectTo{Float64}()"
end

@testset "fallback" begin
# Generic struct fallback: same-type value, zero tangent, thunk and structural Tangent.
@test Fred(1.2) == ProjectTo(Fred(1.1))(Fred(1.2))
@test Fred(0.0) == ProjectTo(Fred(1.1))(ZeroTangent())
@test Fred(3.2) == ProjectTo(Fred(1.1))(@thunk(Fred(3.2)))
@test Fred(1.2) == ProjectTo(Fred(1.1))(Tangent{Fred}(;a=1.2))

# struct with complicated field
x = Freddy(zeros(2,2))
dx = Tangent{Freddy}(; a=ZeroTangent())
@test x == ProjectTo(x)(dx)

# nested structs
f = Fred(0.0)
tf = Tangent{Fred}(;a=ZeroTangent())
m = Mary(f)
dm = Tangent{Mary}(;a=tf)
@test m == ProjectTo(m)(dm)

# two fields: each Tangent sets only one field, the other is left implicit
tfa = TwoFields(3.0, 0.0)
tfb = TwoFields(0.0, 3.0)
@test tfa == ProjectTo(tfa)(Tangent{TwoFields}(; a=3.0))
@test tfb == ProjectTo(tfb)(Tangent{TwoFields}(; b=3.0))
end

@testset "to Real" begin
# Float64
@test 3.2 == ProjectTo(1.0)(3.2)
@test 0.0 == ProjectTo(1.0)(ZeroTangent())
@test 3.2 == ProjectTo(1.0)(@thunk(3.2))

# down: complex input loses its imaginary part, Float64 input is narrowed to Float32
@test 3.2 == ProjectTo(1.0)(3.2 + 3im)
@test 3.2f0 == ProjectTo(1.0f0)(3.2)
@test 3.2f0 == ProjectTo(1.0f0)(3.2 - 3im)

# up: Float32 input is widened to Float64
@test 2.0 == ProjectTo(1.0)(2.0f0)
end

@testset "to Number" begin
# Complex
@test 2.0 + 4.0im == ProjectTo(1.0im)(2.0 + 4.0im)

# from a Real input and from zero tangents (originally labelled "down")
@test 2.0 + 0.0im == ProjectTo(1.0im)(2.0)
@test 0.0 + 0.0im == ProjectTo(1.0im)(ZeroTangent())
@test 0.0 + 0.0im == ProjectTo(1.0im)(@thunk(ZeroTangent()))

# up
# NOTE(review): identical to the first Real-input assertion above; likely redundant.
@test 2.0 + 0.0im == ProjectTo(1.0im)(2.0)
end

@testset "to Array" begin
# to an array of numbers
x = zeros(2, 2)
@test [1.0 2.0; 3.0 4.0] == ProjectTo(x)([1.0 2.0; 3.0 4.0])
@test x == ProjectTo(x)(ZeroTangent())

x = zeros(2)
@test x == ProjectTo(x)(@thunk(ZeroTangent()))

# Float64 differential projected onto a Float32 array
x = zeros(Float32, 2, 2)
@test x == ProjectTo(x)([0.0 0; 0 0])

# structured (Diagonal) differential projected onto a dense array
x = [1.0 0; 0 4]
@test x == ProjectTo(x)(Diagonal([1.0, 4]))

# to an array of structs
x = [Fred(0.0), Fred(0.0)]
@test x == ProjectTo(x)([Fred(0.0), Fred(0.0)])
@test x == ProjectTo(x)([ZeroTangent(), ZeroTangent()])
@test x == ProjectTo(x)([ZeroTangent(), @thunk(Fred(0.0))])
@test x == ProjectTo(x)(ZeroTangent())
@test x == ProjectTo(x)(@thunk(ZeroTangent()))

x = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)]
@test x == ProjectTo(x)(Diagonal([Fred(1.0), Fred(4.0)]))
end

@testset "To Array of Arrays" begin
# inner arrays have same type but different sizes
x = [[1.0, 2.0, 3.0], [4.0, 5.0]]
@test x == ProjectTo(x)(x)
# complex inner entries are projected back onto the real inner arrays
@test x == ProjectTo(x)([[1.0 + 2im, 2.0, 3.0], [4.0 + 2im, 5.0]])

# This makes sure we don't fall for https://github.com/JuliaLang/julia/issues/38064
@test [[0.0, 0.0, 0.0], [0.0, 0.0]] == ProjectTo(x)(ZeroTangent())
end

@testset "Array{Any} with really messy contents" begin
# elements differ in type and shape: a real vector, a complex 1x2 matrix, and a nested
# array of structs
x = [[1.0, 2.0, 3.0], [4.0+im 5.0], [[[Fred(1)]]]]
@test x == ProjectTo(x)(x)
@test x == ProjectTo(x)([[1.0+im, 2.0, 3.0], [4.0+im 5.0], [[[Fred(1)]]]])
# using a different type for the 2nd element (Adjoint)
@test x == ProjectTo(x)([[1.0+im, 2.0, 3.0], [4.0-im, 5.0]', [[[Fred(1)]]]])

@test [[0.0, 0.0, 0.0], [0.0im 0.0], [[[Fred(0)]]]] == ProjectTo(x)(ZeroTangent())
end

@testset "to Diagonal" begin
# zero-valued Diagonal primals with different element types
d_F64 = Diagonal([0.0, 0.0])
d_F32 = Diagonal([0.0f0, 0.0f0])
d_C64 = Diagonal([0.0 + 0im, 0.0])
d_Fred = Diagonal([Fred(0.0), Fred(0.0)])

# from Matrix
@test d_F64 == ProjectTo(d_F64)(zeros(2, 2))
@test d_F64 == ProjectTo(d_F64)(zeros(Float32, 2, 2))
@test d_F64 == ProjectTo(d_F64)(zeros(ComplexF64, 2, 2))

# from Diagonal of Numbers
@test d_F64 == ProjectTo(d_F64)(d_F64)
@test d_F64 == ProjectTo(d_F64)(d_F32)
@test d_F64 == ProjectTo(d_F64)(d_C64)

# from Diagonal of AbstractTangent
@test d_F64 == ProjectTo(d_F64)(ZeroTangent())
@test d_C64 == ProjectTo(d_C64)(ZeroTangent())
@test d_F64 == ProjectTo(d_F64)(@thunk(ZeroTangent()))
@test d_F64 == ProjectTo(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()]))
@test d_F64 == ProjectTo(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())]))

# from Diagonal of structs
@test d_Fred == ProjectTo(d_Fred)(ZeroTangent())
@test d_Fred == ProjectTo(d_Fred)(@thunk(ZeroTangent()))
@test d_Fred == ProjectTo(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()]))

# from Tangent
@test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0]))
@test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0]))
@test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())]))
end

@testset "to $SymHerm" for SymHerm in (Symmetric, Hermitian)
# The same checks run once for Symmetric and once for Hermitian.
data = [1.0+1im 2-2im; 3 4]

# default `uplo`: from a plain matrix and from a structural Tangent
x = SymHerm(data)
@test x == ProjectTo(x)(data)
@test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data, uplo=NoTangent()))

# explicit lower-triangle storage
x = SymHerm(data, :L)
@test x == ProjectTo(x)(data)

# from a Diagonal differential
data = [1.0-2im 0; 0 4]
x = SymHerm(data)
@test x == ProjectTo(x)(Diagonal([1.0-2im, 4.0]))

# from zero tangents
data = [0.0+0im 0; 0 0]
x = SymHerm(data)
@test x == ProjectTo(x)(ZeroTangent())
@test x == ProjectTo(x)(@thunk(ZeroTangent()))
end

@testset "to Transpose" begin
# `mt` and `ma` are the materialized transpose and adjoint of the same 3x4 complex matrix.
x = rand(ComplexF64, 3, 4)
t = transpose(x)
mt = collect(t)
a = adjoint(x)
ma = collect(a)

@test t == ProjectTo(t)(mt)
# the adjoint is the elementwise conjugate of the transpose
@test conj(t) == ProjectTo(t)(ma)
# zero tangents give a zero matrix with the transposed (4x3) shape
@test zeros(4, 3) == ProjectTo(t)(ZeroTangent())
@test zeros(4, 3) == ProjectTo(t)(Tangent{Transpose}(; parent=ZeroTangent()))
end

@testset "to Adjoint" begin
    # `ma` is the materialized adjoint of a 3x4 complex matrix.
    x = rand(ComplexF64, 3, 4)
    a = adjoint(x)
    ma = collect(a)

    @test a == ProjectTo(a)(ma)
    # zero tangents give a zero matrix with the adjoint's (4x3) shape
    @test zeros(4, 3) == ProjectTo(a)(ZeroTangent())
    @test zeros(4, 3) == ProjectTo(a)(Tangent{Adjoint}(; parent=ZeroTangent()))
end

@testset "to SubArray" begin
    x = rand(3, 4)
    sa = view(x, :, 1:2)
    m = collect(sa)

    # make sure it converts the view to the parent type
    @test ProjectTo(sa)(m) isa Matrix
    # `@test` requires a Boolean expression, so assert on the type of the projected zero
    # rather than passing the projected array itself.
    @test ProjectTo(sa)(ZeroTangent()) isa Matrix
    @test ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) isa Matrix
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Test
end

include("accumulation.jl")
include("projection.jl")

include("rules.jl")
include("rule_definition_tools.jl")
Expand Down