Merge pull request #73 from JuliaDiff/realsubtype
[RFC] Make ForwardDiffNumbers subtypes of Real
jrevels committed Dec 7, 2015
2 parents 2f50510 + dd74a42 commit 32ae0d8
Showing 8 changed files with 107 additions and 56 deletions.
2 changes: 1 addition & 1 deletion docs/source/how_it_works.rst
@@ -8,7 +8,7 @@ As previously stated, ForwardDiff.jl is an implementation of `forward mode autom
New Number Types
----------------

ForwardDiff.jl provides several new number types, which are all subtypes of the abstract type ``ForwardDiffNumber{N,T,C} <: Number``. These number types store both normal values and the values of partial derivatives.
ForwardDiff.jl provides several new number types, which are all subtypes of the abstract type ``ForwardDiffNumber{N,T,C} <: Real``. These number types store both normal values and the values of partial derivatives.

Elementary numerical functions on these types are overloaded to evaluate both the original function *and* the partial derivatives of the function, storing the results in a ``ForwardDiffNumber``. We can then pass these number types into a general function :math:`f` (which is assumed to be composed of the overloaded elementary functions), and the derivative information is naturally propagated at each step of the calculation by way of the chain rule.
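To make the chain-rule propagation concrete, here is a minimal sketch of the same idea with a single partial derivative. The ``DualSketch`` type and its overloads are illustrative inventions, not ForwardDiff.jl's actual implementation (Julia 0.4-era syntax, to match this commit):

import Base: +, *, sin

# Illustrative one-partial "dual number": stores f(x) alongside f'(x).
immutable DualSketch <: Real
    value::Float64    # the normal value
    partial::Float64  # the partial derivative carried along with it
end

# Each overloaded elementary function evaluates the original operation
# *and* applies the chain rule to the stored partial:
+(a::DualSketch, b::DualSketch) = DualSketch(a.value + b.value, a.partial + b.partial)
*(a::DualSketch, b::DualSketch) = DualSketch(a.value * b.value,
                                             a.partial * b.value + a.value * b.partial)
sin(d::DualSketch) = DualSketch(sin(d.value), cos(d.value) * d.partial)

# Seeding the input with partial == 1 makes the carried partial equal the
# derivative of the whole composition:
x = DualSketch(2.0, 1.0)
y = sin(x * x)   # y.value == sin(4.0); y.partial == 4.0*cos(4.0) == d/dx sin(x^2) at x = 2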

6 changes: 4 additions & 2 deletions docs/source/perf_diff.rst
@@ -6,13 +6,15 @@ Restrictions on the target function

ForwardDiff.jl can only differentiate functions that adhere to the following rules:

- **The function can only be composed of generic Julia functions.** ForwardDiff cannot propagate derivative information through non-Julia code. Thus, your function may not work if it makes calls to external, non-Julia programs, e.g. uses explicit BLAS calls instead of ``Ax_mul_Bx``-style functions.

- **The function must be unary (i.e., only accept a single argument).** The ``jacobian`` function is the exception to this restriction; see below for details.

- **The function must accept an argument whose type is a subtype of** ``Vector`` **or** ``Real``. The argument type does not need to be annotated in the function definition.

- **The function's argument type cannot be too restrictively annotated.** In this case, "too restrictive" means more restrictive than ``x::Vector`` or ``x::Number``.
- **The function's argument type cannot be too restrictively annotated.** In this case, "too restrictive" means more restrictive than ``x::Vector`` or ``x::Real``.

- **All number types involved in the function must be subtypes of** ``Real``. We believe extension to subtypes of ``Complex`` is possible, but it hasn't yet been worked on.
- **All number types involved in the function must be subtypes of** ``Real``. We believe extension to subtypes of ``Complex`` is possible, but it hasn't yet been worked on. Note that custom (i.e. non-Base) subtypes of ``Real`` are not supported.

- **The function must be** `type-stable`_ **.** This is not a strict limitation in every case, but lack of type-stability can sometimes cause errors; at the very least, it can severely hinder performance. (The sketch after this list illustrates several of these rules.)
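A sketch of what these rules allow and forbid; the function names here are hypothetical, not part of ForwardDiff.jl:

# Fine: unary, generic (or loosely annotated), composed of Julia functions.
f_ok(x) = sin(x) + x^2
f_ok2(x::Vector) = sum(x) * 2.0

# Too restrictive: ForwardDiff's number types are Reals but not Float64s,
# so they cannot be passed into this method.
f_too_strict(x::Vector{Float64}) = sum(x)

# Type-unstable: ``s`` starts as an Int and changes type mid-loop whenever
# the elements are not Ints, which can cause errors or badly hurt performance.
function f_unstable(x::Vector)
    s = 0
    for v in x
        s += v
    end
    return s
end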

81 changes: 73 additions & 8 deletions src/ForwardDiffNumber.jl
@@ -1,4 +1,9 @@
abstract ForwardDiffNumber{N,T<:Number,C} <: Number
###############################
# Abstract Types/Type Aliases #
###############################

@eval typealias ExternalReal Union{$(subtypes(Real)...)}
abstract ForwardDiffNumber{N,T<:Real,C} <: Real
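The ``@eval typealias`` line deserves a note: ``subtypes(Real)`` runs once, when the file is loaded, so ``ExternalReal`` becomes a fixed ``Union`` of every ``Real`` subtype that exists *before* ``ForwardDiffNumber`` is declared. A standalone sketch of the trick, with an illustrative type name:

# subtypes(Real) is evaluated when @eval expands, so the alias captures
# only the Real subtypes defined up to this point:
@eval typealias OutsideReal Union{$(subtypes(Real)...)}

# A Real subtype declared afterwards is deliberately excluded:
abstract MyNumber <: Real

# Methods written against OutsideReal (e.g. the convert definitions later in
# this file) can therefore never match MyNumber, which avoids ambiguity with
# methods that take the package's own number types.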

# Subtypes F<:ForwardDiffNumber should define:
# npartials(::Type{F}) --> N from ForwardDiffNumber{N,T,C}
@@ -25,6 +30,55 @@ abstract ForwardDiffNumber{N,T<:Number,C} <: Number
# write(io::IO, n::F)
# conversion/promotion rules

####################
# Type Definitions #
####################

immutable GradientNumber{N,T,C} <: ForwardDiffNumber{N,T,C}
value::T
partials::Partials{T,C}
end

immutable HessianNumber{N,T,C} <: ForwardDiffNumber{N,T,C}
gradnum::GradientNumber{N,T,C}
hess::Vector{T}
function HessianNumber(gradnum, hess)
@assert length(hess) == halfhesslen(N)
return new(gradnum, hess)
end
end

immutable TensorNumber{N,T,C} <: ForwardDiffNumber{N,T,C}
hessnum::HessianNumber{N,T,C}
tens::Vector{T}
function TensorNumber(hessnum, tens)
@assert length(tens) == halftenslen(N)
return new(hessnum, tens)
end
end
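Note how the three types nest: a ``HessianNumber`` wraps a ``GradientNumber``, and a ``TensorNumber`` wraps a ``HessianNumber``, so each derivative order reuses the storage of the order below it. A hypothetical construction for N = 2 partials, using the convenience constructors defined later in this commit (and assuming ``halfhesslen``/``halftenslen`` count the unique entries of the symmetric derivative tensors, i.e. N(N+1)/2 and N(N+1)(N+2)/6):

g = GradientNumber(1.0, (2.0, 3.0))  # value plus 2 partials
h = HessianNumber(g)                 # appends the 3 unique Hessian entries (zeroed by default)
t = TensorNumber(h)                  # appends the 4 unique tensor entries (zeroed by default)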


##################################
# Ambiguous Function Definitions #
##################################
ambiguous_error{A,B}(f, a::A, b::B) = error("""Oops - $f(::$A, ::$B) should never have been called.
It was defined to resolve ambiguity, and was supposed to
fall back to a more specific method defined elsewhere.
Please report this bug to ForwardDiff.jl's issue tracker.""")

ambiguous_binary_funcs = [:(==), :isequal, :isless, :+, :-, :*, :/, :^, :atan2, :calc_atan2]
fdnum_ambiguous_binary_funcs = [:(==), :isequal, :isless]
fdnum_types = [:GradientNumber, :HessianNumber, :TensorNumber]

for f in ambiguous_binary_funcs
if f in fdnum_ambiguous_binary_funcs
@eval $f(a::ForwardDiffNumber, b::ForwardDiffNumber) = ambiguous_error($f, a, b)
end
for A in fdnum_types, B in fdnum_types
@eval $f(a::$A, b::$B) = ambiguous_error($f, a, b)
end
end
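To see why these stubs exist: with ``ForwardDiffNumber <: Real``, generic method pairs such as ``f(n::ForwardDiffNumber, x::Real)`` / ``f(x::Real, n::ForwardDiffNumber)`` become ambiguous when *both* arguments are ForwardDiffNumbers. The loop stamps out an explicit method for every concrete pair so dispatch always has a unique winner; one hand-expanded iteration looks like:

# Generated for f = :+, A = GradientNumber, B = HessianNumber:
+(a::GradientNumber, b::HessianNumber) = ambiguous_error(+, a, b)

As the error message says, more specific methods defined elsewhere are expected to intercept every meaningful call, so a stub actually firing would be a bug.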

##############################
# Utility/Accessor Functions #
##############################
@@ -43,15 +97,22 @@ abstract ForwardDiffNumber{N,T<:Number,C} <: Number
@inline eltype{N,T,C}(::Type{ForwardDiffNumber{N,T,C}}) = T
@inline containtype{N,T,C}(::Type{ForwardDiffNumber{N,T,C}}) = C

==(n::ForwardDiffNumber, x::Real) = isconstant(n) && (value(n) == x)
==(x::Real, n::ForwardDiffNumber) = ==(n, x)
for T in fdnum_types
@eval isless(a::$T, b::$T) = value(a) < value(b)
end

isequal(n::ForwardDiffNumber, x::Real) = isconstant(n) && isequal(value(n), x)
isequal(x::Real, n::ForwardDiffNumber) = isequal(n, x)
for T in (Base.Irrational, AbstractFloat, Real)
@eval begin
==(n::ForwardDiffNumber, x::$T) = isconstant(n) && (value(n) == x)
==(x::$T, n::ForwardDiffNumber) = ==(n, x)

isless(a::ForwardDiffNumber, b::ForwardDiffNumber) = value(a) < value(b)
isless(x::Real, n::ForwardDiffNumber) = x < value(n)
isless(n::ForwardDiffNumber, x::Real) = value(n) < x
isequal(n::ForwardDiffNumber, x::$T) = isconstant(n) && isequal(value(n), x)
isequal(x::$T, n::ForwardDiffNumber) = isequal(n, x)

isless(x::$T, n::ForwardDiffNumber) = x < value(n)
isless(n::ForwardDiffNumber, x::$T) = value(n) < x
end
end
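The net effect of these comparison rules, sketched as a hypothetical session with N = 2:

g0 = GradientNumber(1.0, (0.0, 0.0))  # zero partials, so isconstant(g0) holds
g1 = GradientNumber(1.0, (1.0, 0.0))  # carries derivative information

g0 == 1.0        # true:  g0 is constant and its value matches
g1 == 1.0        # false: a non-constant number never equals a plain real
isless(g1, 2.0)  # true:  ordering consults only value(g1)
isless(g0, g1)   # false: same-type comparison also uses values, and 1.0 < 1.0 fails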

copy(n::ForwardDiffNumber) = n # assumes all types of ForwardDiffNumbers are immutable

Expand All @@ -74,6 +135,10 @@ function float(n::ForwardDiffNumber)
return convert(switch_eltype(typeof(n), T), n)
end

convert(::Type{Integer}, n::ForwardDiffNumber) = isconstant(n) ? Integer(value(n)) : throw(InexactError())
convert(::Type{Bool}, n::ForwardDiffNumber) = isconstant(n) ? Bool(value(n)) : throw(InexactError())
convert{T<:ExternalReal}(::Type{T}, n::ForwardDiffNumber) = isconstant(n) ? T(value(n)) : throw(InexactError())
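These rules let a ForwardDiffNumber collapse to a plain number only when no derivative information is lost; a hypothetical session:

c = GradientNumber(3.0, (0.0,))  # constant
n = GradientNumber(3.0, (1.0,))  # non-constant

convert(Float64, c)  # 3.0 -- Float64 <: ExternalReal and c is constant
convert(Int, c)      # 3   -- likewise for integer targets
convert(Float64, n)  # throws InexactError(): n's partials would be discarded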

##################
# Math Functions #
##################
14 changes: 5 additions & 9 deletions src/GradientNumber.jl
@@ -1,8 +1,6 @@
immutable GradientNumber{N,T,C} <: ForwardDiffNumber{N,T,C}
value::T
partials::Partials{T,C}
end

################
# Constructors #
################
GradientNumber{N,T}(value::T, grad::NTuple{N,T}) = GradientNumber{N,T,NTuple{N,T}}(value, Partials(grad))
GradientNumber{T}(value::T, grad::T...) = GradientNumber(value, grad)

@@ -53,10 +51,8 @@ end
convert{N,T,C}(::Type{GradientNumber{N,T,C}}, g::GradientNumber) = GradientNumber{N,T,C}(value(g), partials(g))
convert{N,T,C}(::Type{GradientNumber{N,T,C}}, g::GradientNumber{N,T,C}) = g
convert(::Type{GradientNumber}, g::GradientNumber) = g

convert{T<:Real}(::Type{T}, g::GradientNumber) = isconstant(g) ? T(value(g)) : throw(InexactError())
convert{N,T,C}(::Type{GradientNumber{N,T,C}}, x::Real) = GradientNumber{N,T,C}(x, zero_partials(C, N))
convert(::Type{GradientNumber}, x::Real) = GradientNumber(x)
convert{N,T,C}(::Type{GradientNumber{N,T,C}}, x::ExternalReal) = GradientNumber{N,T,C}(x, zero_partials(C, N))
convert(::Type{GradientNumber}, x::ExternalReal) = GradientNumber(x)

############################
# Math with GradientNumber #
16 changes: 4 additions & 12 deletions src/HessianNumber.jl
@@ -1,12 +1,6 @@
immutable HessianNumber{N,T,C} <: ForwardDiffNumber{N,T,C}
gradnum::GradientNumber{N,T,C}
hess::Vector{T}
function HessianNumber(gradnum, hess)
@assert length(hess) == halfhesslen(N)
return new(gradnum, hess)
end
end

################
# Constructors #
################
function HessianNumber{N,T,C}(gradnum::GradientNumber{N,T,C},
hess::Vector=zeros(T, halfhesslen(N)))
return HessianNumber{N,T,C}(gradnum, hess)
@@ -63,9 +57,7 @@ end
##############
# Conversion #
##############
convert{N,T,C}(::Type{HessianNumber{N,T,C}}, x::Real) = HessianNumber(GradientNumber{N,T,C}(x))
convert{T<:Real}(::Type{T}, h::HessianNumber) = isconstant(h) ? T(value(h)) : throw(InexactError())

convert{N,T,C}(::Type{HessianNumber{N,T,C}}, x::ExternalReal) = HessianNumber(GradientNumber{N,T,C}(x))
convert{N,T,C}(::Type{HessianNumber{N,T,C}}, h::HessianNumber{N}) = HessianNumber(GradientNumber{N,T,C}(gradnum(h)), hess(h))
convert{N,T,C}(::Type{HessianNumber{N,T,C}}, h::HessianNumber{N,T,C}) = h
convert(::Type{HessianNumber}, h::HessianNumber) = h
8 changes: 6 additions & 2 deletions src/Partials.jl
@@ -31,10 +31,10 @@ done(partials, i) = done(data(partials), i)
################
# Constructors #
################
@inline zero_partials{N,T}(::Type{NTuple{N,T}}, n::Int) = Partials(zero_tuple(NTuple{N,T}))
@inline zero_partials{C<:Tuple}(::Type{C}, n::Int) = Partials(zero_tuple(C))
zero_partials{T}(::Type{Vector{T}}, n) = Partials(zeros(T, n))

@inline rand_partials{N,T}(::Type{NTuple{N,T}}, n::Int) = Partials(rand_tuple(NTuple{N,T}))
@inline rand_partials{C<:Tuple}(::Type{C}, n::Int) = Partials(rand_tuple(C))
rand_partials{T}(::Type{Vector{T}}, n::Int) = Partials(rand(T, n))

#####################
@@ -171,6 +171,8 @@ function tupexpr(f,N)
end
end

@inline zero_tuple(::Type{Tuple{}}) = tuple()

@generated function zero_tuple{N,T}(::Type{NTuple{N,T}})
result = tupexpr(i -> :z, N)
return quote
@@ -179,6 +181,8 @@ end
end
end

@inline rand_tuple(::Type{Tuple{}}) = tuple()

@generated function rand_tuple{N,T}(::Type{NTuple{N,T}})
return tupexpr(i -> :(rand($T)), N)
end
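The two ``Tuple{}`` methods added here cover the zero-partial edge case: an empty tuple type carries no element type, so it presumably cannot match the ``@generated`` definitions, which need to recover ``T`` from ``NTuple{N,T}``. A sketch of the resulting dispatch:

zero_tuple(Tuple{})            # () -- via the explicit empty-tuple method
zero_tuple(NTuple{2,Float64})  # (0.0, 0.0) -- via the @generated method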
16 changes: 4 additions & 12 deletions src/TensorNumber.jl
@@ -1,12 +1,6 @@
immutable TensorNumber{N,T,C} <: ForwardDiffNumber{N,T,C}
hessnum::HessianNumber{N,T,C}
tens::Vector{T}
function TensorNumber(hessnum, tens)
@assert length(tens) == halftenslen(N)
return new(hessnum, tens)
end
end

################
# Constructors #
################
function TensorNumber{N,T,C}(hessnum::HessianNumber{N,T,C},
tens::Vector=zeros(T, halftenslen(N)))
return TensorNumber{N,T,C}(hessnum, tens)
@@ -64,9 +58,7 @@ end
########################
# Conversion/Promotion #
########################
convert{N,T,C}(::Type{TensorNumber{N,T,C}}, x::Real) = TensorNumber(HessianNumber{N,T,C}(x))
convert{T<:Real}(::Type{T}, t::TensorNumber) = isconstant(t) ? T(value(t)) : throw(InexactError())

convert{N,T,C}(::Type{TensorNumber{N,T,C}}, x::ExternalReal) = TensorNumber(HessianNumber{N,T,C}(x))
convert{N,T,C}(::Type{TensorNumber{N,T,C}}, t::TensorNumber{N}) = TensorNumber(HessianNumber{N,T,C}(hessnum(t)), tens(t))
convert{N,T,C}(::Type{TensorNumber{N,T,C}}, t::TensorNumber{N,T,C}) = t
convert(::Type{TensorNumber}, t::TensorNumber) = t
20 changes: 10 additions & 10 deletions test/test_gradients.jl
@@ -1,7 +1,7 @@
using Base.Test
using Calculus
using ForwardDiff
using ForwardDiff:
GradientNumber,
value,
grad,
@@ -31,7 +31,7 @@ for (test_partials, Grad) in ((test_partialstup, ForwardDiff.GradNumTup), (test_
end

@test npartials(test_grad) == npartials(typeof(test_grad)) == N

##################################
# Value Representation Functions #
##################################
@@ -92,7 +92,7 @@ for (test_partials, Grad) in ((test_partialstup, ForwardDiff.GradNumTup), (test_
@test isnan(Grad{3,T}(NaN))

not_const_grad = Grad{N,T}(one(T), map(one, test_partials))
@test !(isconstant(not_const_grad))
@test !(isreal(not_const_grad))
@test isconstant(const_grad) && isreal(const_grad)
@test isconstant(zero(not_const_grad)) && isreal(zero(not_const_grad))
@@ -116,7 +116,7 @@ for (test_partials, Grad) in ((test_partialstup, ForwardDiff.GradNumTup), (test_
seekstart(io)

@test read(io, typeof(test_grad)) == test_grad

close(io)

#####################################
@@ -136,7 +136,7 @@ for (test_partials, Grad) in ((test_partialstup, ForwardDiff.GradNumTup), (test_
@test rand_val + test_grad == test_grad + rand_val
@test rand_val - test_grad == Grad{N,T}(rand_val-test_val, map(-, test_partials))
@test test_grad - rand_val == Grad{N,T}(test_val-rand_val, test_partials)

@test -test_grad == Grad{N,T}(-test_val, map(-, test_partials))

# Multiplication #
@@ -196,12 +196,12 @@ for (test_partials, Grad) in ((test_partialstup, ForwardDiff.GradNumTup), (test_
end

x = value(orig_grad)
df = $expr

@test_approx_eq value(f_grad) func(x)

for i in 1:N
try
@test_approx_eq grad(f_grad, i) df*grad(orig_grad, i)
catch exception
info("The exception was thrown while testing function $func at value $orig_grad")
@@ -260,13 +260,13 @@ end
chunk_sizes = (ForwardDiff.default_chunk_size, 1, Int(N/2), N)

for fsym in map(first, Calculus.symbolic_derivatives_1arg())
testexpr = :($(fsym)(a) + $(fsym)(b) - $(fsym)(c) * $(fsym)(d))

@eval function testf(x::Vector)
a,b,c,d = x
return $testexpr
end

for chunk in chunk_sizes
try
testx = grad_test_x(fsym, N)
