Skip to content

Commit

Permalink
only use accurate powf function
Browse files Browse the repository at this point in the history
The powi intrinsic optimization over calling powf is that it is inaccurate,

When it is equally accurate (e.g. tiny constant powers)
LLVM will already recongnize and optimize any call to a function named `powf`,
and produce the same speedup.

fix #19872
  • Loading branch information
vtjnash committed Jan 6, 2017
1 parent 6eaec18 commit e3344af
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 55 deletions.
8 changes: 4 additions & 4 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module FastMath

export @fastmath

import Core.Intrinsics: box, unbox, powi_llvm, sqrt_llvm_fast
import Core.Intrinsics: box, unbox, powf_llvm, sqrt_llvm_fast

const fast_op =
Dict(# basic arithmetic
Expand Down Expand Up @@ -250,9 +250,9 @@ end

# builtins

pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, Int32(y))
pow_fast{T<:FloatTypes}(x::T, y::Int32) =
box(T, Base.powi_llvm(unbox(T,x), unbox(Int32,y)))
pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, convert(T, y))
pow_fast{T<:FloatTypes}(x::T, y::T) =
box(T, Base.powf_llvm(unbox(T,x), unbox(T,y)))

# TODO: Change sqrt_llvm intrinsic to avoid nan checking; add nan
# checking to sqrt in math.jl; remove sqrt_llvm_fast intrinsic
Expand Down
6 changes: 3 additions & 3 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import Base: log, exp, sin, cos, tan, sinh, cosh, tanh, asin,
significand_mask, significand_bits, exponent_bits, exponent_bias


import Core.Intrinsics: sqrt_llvm, box, unbox, powi_llvm
import Core.Intrinsics: sqrt_llvm, box, unbox, powf_llvm

# non-type specific math functions

Expand Down Expand Up @@ -648,9 +648,9 @@ end
^(x::Float32, y::Float32) = nan_dom_err(ccall((:powf,libm), Float32, (Float32,Float32), x, y), x+y)

^(x::Float64, y::Integer) =
box(Float64, powi_llvm(unbox(Float64,x), unbox(Int32,Int32(y))))
box(Float64, powf_llvm(unbox(Float64,x), unbox(Float64,Float64(y))))
^(x::Float32, y::Integer) =
box(Float32, powi_llvm(unbox(Float32,x), unbox(Int32,Int32(y))))
box(Float32, powf_llvm(unbox(Float32,x), unbox(Float32,Float32(y))))
^(x::Float16, y::Integer) = Float16(Float32(x)^y)

function angle_restrict_symm(theta)
Expand Down
5 changes: 1 addition & 4 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,8 @@ static Function *expect_func;
static Function *jldlsym_func;
static Function *jlnewbits_func;
static Function *jltypeassert_func;
#if JL_LLVM_VERSION < 30600
static Function *jlpow_func;
static Function *jlpowf_func;
#endif
//static Function *jlgetnthfield_func;
static Function *jlgetnthfieldchecked_func;
//static Function *jlsetnthfield_func;
Expand Down Expand Up @@ -5946,7 +5944,6 @@ static void init_julia_llvm_env(Module *m)
"jl_gc_diff_total_bytes", m);
add_named_global(diff_gc_total_bytes_func, *jl_gc_diff_total_bytes);

#if JL_LLVM_VERSION < 30600
Type *powf_type[2] = { T_float32, T_float32 };
jlpowf_func = Function::Create(FunctionType::get(T_float32, powf_type, false),
Function::ExternalLinkage,
Expand All @@ -5964,7 +5961,7 @@ static void init_julia_llvm_env(Module *m)
&pow,
#endif
false);
#endif

std::vector<Type*> array_owner_args(0);
array_owner_args.push_back(T_pjlvalue);
jlarray_data_owner_func =
Expand Down
19 changes: 5 additions & 14 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1452,23 +1452,14 @@ static Value *emit_untyped_intrinsic(intrinsic f, Value *x, Value *y, Value *z,
ArrayRef<Type*>(x->getType())),
x);
}
case powi_llvm: {
case powf_llvm: {
x = FP(x);
y = JL_INT(y);
Type *tx = x->getType(); // TODO: LLVM expects this to be i32
#if JL_LLVM_VERSION >= 30600
Type *ts[1] = { tx };
Value *powi = Intrinsic::getDeclaration(jl_Module, Intrinsic::powi,
ArrayRef<Type*>(ts));
y = FP(y);
Function *powf = (x->getType() == T_float64 ? jlpow_func : jlpowf_func);
#if JL_LLVM_VERSION >= 30700
return builder.CreateCall(powi, {x, y});
return builder.CreateCall(prepare_call(powf), {x, y});
#else
return builder.CreateCall2(powi, x, y);
#endif
#else
// issue #6506
return builder.CreateCall2(prepare_call(tx == T_float64 ? jlpow_func : jlpowf_func),
x, builder.CreateSIToFP(y, tx));
return builder.CreateCall2(prepare_call(powf), x, y);
#endif
}
case sqrt_llvm_fast: {
Expand Down
2 changes: 1 addition & 1 deletion src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
ADD_I(trunc_llvm, 1) \
ADD_I(rint_llvm, 1) \
ADD_I(sqrt_llvm, 1) \
ADD_I(powi_llvm, 2) \
ADD_I(powf_llvm, 2) \
ALIAS(sqrt_llvm_fast, sqrt_llvm) \
/* pointer access */ \
ADD_I(pointerref, 3) \
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ JL_DLLEXPORT jl_value_t *jl_floor_llvm(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_trunc_llvm(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_rint_llvm(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_sqrt_llvm(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_powf_llvm(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_abs_float(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_copysign_float(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_flipsign_int(jl_value_t *a, jl_value_t *b);
Expand Down
28 changes: 3 additions & 25 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,8 @@ bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
*pr = fp_select(a, sqrt)
#define copysign_float(a, b) \
fp_select2(a, b, copysign)
#define pow_float(a, b) \
fp_select2(a, b, pow)

un_fintrinsic(abs_float,abs_float)
bi_fintrinsic(copysign_float,copysign_float)
Expand All @@ -955,31 +957,7 @@ un_fintrinsic(floor_float,floor_llvm)
un_fintrinsic(trunc_float,trunc_llvm)
un_fintrinsic(rint_float,rint_llvm)
un_fintrinsic(sqrt_float,sqrt_llvm)

JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_value_t *ty = jl_typeof(a);
if (!jl_is_bitstype(ty))
jl_error("powi_llvm: a is not a bitstype");
if (!jl_is_bitstype(jl_typeof(b)) || jl_datatype_size(jl_typeof(b)) != 4)
jl_error("powi_llvm: b is not a 32-bit bitstype");
int sz = jl_datatype_size(ty);
jl_value_t *newv = jl_gc_alloc(ptls, sz, ty);
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
switch (sz) {
/* choose the right size c-type operation */
case 4:
*(float*)pr = powf(*(float*)pa, (float)jl_unbox_int32(b));
break;
case 8:
*(double*)pr = pow(*(double*)pa, (double)jl_unbox_int32(b));
break;
default:
jl_error("powi_llvm: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
}
return newv;
}
bi_fintrinsic(pow_float,powf_llvm)

JL_DLLEXPORT jl_value_t *jl_select_value(jl_value_t *isfalse, jl_value_t *a, jl_value_t *b)
{
Expand Down
10 changes: 7 additions & 3 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -959,9 +959,13 @@ end
end

@testset "issue #13748" begin
let A = [1 2; 3 4]; B = [5 6; 7 8]; C = [9 10; 11 12]
@test muladd(A,B,C) == A*B + C
end
end

@testset "issue #19872" begin
f19872(x) = x ^ 3
@test issubnormal(2.0 ^ (-1024))
@test f19872(2.0) === 8.0
@test !issubnormal(0.0)
end

@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.))
Expand Down

0 comments on commit e3344af

Please sign in to comment.