Implement muladd #9840

Merged: 2 commits, Jan 22, 2015

1 change: 1 addition & 0 deletions base/exports.jl
@@ -408,6 +408,7 @@ export
mod1,
modf,
mod2pi,
muladd,
nextfloat,
nextpow,
nextpow2,

2 changes: 2 additions & 0 deletions base/float.jl
@@ -200,6 +200,8 @@ widen(::Type{Float32}) = Float64

fma(x::Float32, y::Float32, z::Float32) = box(Float32,fma_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z)))
fma(x::Float64, y::Float64, z::Float64) = box(Float64,fma_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z)))
muladd(x::Float32, y::Float32, z::Float32) = box(Float32,muladd_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z)))
muladd(x::Float64, y::Float64, z::Float64) = box(Float64,muladd_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z)))

# TODO: faster floating point div?
# TODO: faster floating point fld?

3 changes: 3 additions & 0 deletions base/float16.jl
@@ -136,6 +136,9 @@ end
function fma(a::Float16, b::Float16, c::Float16)
    float16(fma(float32(a), float32(b), float32(c)))
end
function muladd(a::Float16, b::Float16, c::Float16)
    float16(muladd(float32(a), float32(b), float32(c)))
end
for op in (:<,:<=,:isless)
@eval ($op)(a::Float16, b::Float16) = ($op)(float32(a), float32(b))
end
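
The wrapper above simply widens the arguments to Float32, does the multiply-add there, and narrows the result back to Float16. A minimal sketch of what that means at the call site (using the lowercase 0.3-era conversion functions this file already uses; the values are illustrative):

    a, b, c = float16(1.5), float16(2.0), float16(0.25)
    muladd(a, b, c)                                       # Float16 result, 3.25
    float16(muladd(float32(a), float32(b), float32(c)))   # the same computation, spelled out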

2 changes: 2 additions & 0 deletions base/promotion.jl
@@ -161,6 +161,7 @@ promote_to_super{T<:Number,S<:Number}(::Type{T}, ::Type{S}, ::Type) =
^(x::Number, y::Number) = ^(promote(x,y)...)

fma(x::Number, y::Number, z::Number) = fma(promote(x,y,z)...)
muladd(x::Number, y::Number, z::Number) = muladd(promote(x,y,z)...)

(&)(x::Integer, y::Integer) = (&)(promote(x,y)...)
(|)(x::Integer, y::Integer) = (|)(promote(x,y)...)
@@ -195,6 +196,7 @@ no_op_err(name, T) = error(name," not defined for ",T)

fma{T<:Number}(x::T, y::T, z::T) = no_op_err("fma", T)
fma(x::Integer, y::Integer, z::Integer) = x*y+z
muladd{T<:Number}(x::T, y::T, z::T) = no_op_err("muladd", T)

(&){T<:Integer}(x::T, y::T) = no_op_err("&", T)
(|){T<:Integer}(x::T, y::T) = no_op_err("|", T)
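
Because the fallback added above is `muladd(x::Number, y::Number, z::Number) = muladd(promote(x,y,z)...)`, mixed-type calls work whenever the arguments share a promotion type. A minimal sketch, assuming the usual Int/Float64 promotion rules:

    muladd(2, 3.0, 1)    # all three arguments promote to Float64, giving 7.0

Note that, unlike `fma`, the files in this diff do not add an Integer-specific `muladd` method, so an all-Integer call would appear to land on the `no_op_err` fallback.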

10 changes: 9 additions & 1 deletion doc/stdlib/math.rst
@@ -81,7 +81,15 @@ Mathematical Operators

   Computes ``x*y+z`` without rounding the intermediate result
   ``x*y``. On some systems this is significantly more expensive than
   ``x*y+z``. ``fma`` is used to improve accuracy in certain
   algorithms. See ``muladd``.

.. function:: muladd(x, y, z)

   Combined multiply-add: computes ``x*y+z`` in an efficient manner.
   Depending on the system, this may be equivalent to ``x*y+z`` or to
   ``fma(x,y,z)``. ``muladd`` is used to improve performance. See
   ``fma``.

.. function:: div(x, y)
              ÷(x, y)
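
As a quick illustration of the distinction these docstrings draw, a sketch at the REPL (exact results depend on whether the platform has hardware FMA and on how the compiler chooses to lower `muladd`):

    x = 1.0 + 2.0^-30      # x^2 has a low-order 2^-60 term that a Float64 product cannot keep

    x * x - 1.0            # the product rounds first, so the 2^-60 term is lost
    fma(x, x, -1.0)        # a single rounding keeps the 2^-60 term, possibly at a speed cost
    muladd(x, x, -1.0)     # whichever of the two forms above is faster on this system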

18 changes: 16 additions & 2 deletions src/intrinsics.cpp
@@ -6,7 +6,7 @@ namespace JL_I {
neg_int, add_int, sub_int, mul_int,
sdiv_int, udiv_int, srem_int, urem_int, smod_int,
neg_float, add_float, sub_float, mul_float, div_float, rem_float,
fma_float,
fma_float, muladd_float,
// fast arithmetic
neg_float_fast, add_float_fast, sub_float_fast,
mul_float_fast, div_float_fast, rem_float_fast,
@@ -986,6 +986,20 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
                                       ArrayRef<Type*>(x->getType())),
             FP(x), FP(y), FP(z));
    }
    HANDLE(muladd_float,3)
#ifdef LLVM34

[Inline review comments on the #ifdef LLVM34 line]

Member:
The API changed for only v3.4? Will this work on v3.3 and versions >= v3.5?

Contributor (author):
LLVM34 doesn't do what you probably think it does -- it means "LLVM 3.4 or later".

Member:
Ah yes, sorry for the noise.

    {
        assert(y->getType() == x->getType());
        assert(z->getType() == y->getType());
        return builder.CreateCall3
            (Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd,
                                       ArrayRef<Type*>(x->getType())),
             FP(x), FP(y), FP(z));
    }
#else
        return math_builder(ctx, true)().
            CreateFAdd(builder.CreateFMul(FP(x), FP(y)), FP(z));
#endif

HANDLE(checked_sadd,2)
HANDLE(checked_uadd,2)
@@ -1323,7 +1337,7 @@ extern "C" void jl_init_intrinsic_functions(void)
ADD_I(sdiv_int); ADD_I(udiv_int); ADD_I(srem_int); ADD_I(urem_int);
ADD_I(smod_int);
ADD_I(neg_float); ADD_I(add_float); ADD_I(sub_float); ADD_I(mul_float);
ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float);
ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float); ADD_I(muladd_float);
ADD_I(neg_float_fast); ADD_I(add_float_fast); ADD_I(sub_float_fast);
ADD_I(mul_float_fast); ADD_I(div_float_fast); ADD_I(rem_float_fast);
ADD_I(eq_int); ADD_I(ne_int);
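
To see which branch of the `#ifdef LLVM34` block above a given build actually takes, one can inspect the IR Julia generates for a small wrapper; a sketch (the output varies with the LLVM version and target):

    f(x, y, z) = muladd(x, y, z)

    # With LLVM 3.4 or later the body should contain a call to the
    # llvm.fmuladd.f64 intrinsic; with older LLVM it shows up as a
    # separate fmul and fadd, matching the #else branch.
    code_llvm(f, (Float64, Float64, Float64))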

29 changes: 29 additions & 0 deletions test/numbers.jl
@@ -105,6 +105,35 @@ let eps = 1//BigInt(2)^200, one_eps = 1+eps,
    @test fma(one_eps256, one_eps256, -1) == BigFloat(one_eps * one_eps - 1)
end

# muladd

let eps = 1//BigInt(2)^30, one_eps = 1+eps,
    eps64 = float64(eps), one_eps64 = float64(one_eps)
    @test eps64 == float64(eps)
    @test one_eps64 == float64(one_eps)
    @test one_eps64 * one_eps64 - 1 != float64(one_eps * one_eps - 1)
    @test isapprox(muladd(one_eps64, one_eps64, -1),
                   float64(one_eps * one_eps - 1))
end

let eps = 1//BigInt(2)^15, one_eps = 1+eps,
    eps32 = float32(eps), one_eps32 = float32(one_eps)
    @test eps32 == float32(eps)
    @test one_eps32 == float32(one_eps)
    @test one_eps32 * one_eps32 - 1 != float32(one_eps * one_eps - 1)
    @test isapprox(muladd(one_eps32, one_eps32, -1),
                   float32(one_eps * one_eps - 1))
end

let eps = 1//BigInt(2)^7, one_eps = 1+eps,
    eps16 = float16(float32(eps)), one_eps16 = float16(float32(one_eps))
    @test eps16 == float16(float32(eps))
    @test one_eps16 == float16(float32(one_eps))
    @test one_eps16 * one_eps16 - 1 != float16(float32(one_eps * one_eps - 1))
    @test isapprox(muladd(one_eps16, one_eps16, -1),
                   float16(float32(one_eps * one_eps - 1)))
end

# lexing typemin(Int64)
@test (-9223372036854775808)^1 == -9223372036854775808
@test [1 -1 -9223372036854775808] == [1 -1 typemin(Int64)]