Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast atan and atan2 functions. #8388

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cc3ef99
Fast vectorizable atan and atan2 functions.
mcourteaux Aug 10, 2024
56cab26
Default to not using fast atan versions if on CUDA.
mcourteaux Aug 10, 2024
59e6d35
Finished fast atan/atan2 functions and tests.
mcourteaux Aug 10, 2024
7c64aa2
Correct attribution.
mcourteaux Aug 10, 2024
020e966
Clang-format
mcourteaux Aug 10, 2024
cde21ca
Weird WebAssembly limits...
mcourteaux Aug 11, 2024
0881fdd
Small improvements to the optimization script.
mcourteaux Aug 11, 2024
73816c6
Polynomial optimization for log, exp, sin, cos with correct ranges.
mcourteaux Aug 11, 2024
a75d68b
Improve fast atan performance tests for GPU.
mcourteaux Aug 12, 2024
1c0f794
Bugfix fast_atan approximation. Fix correctness test to exceed the ra…
mcourteaux Aug 12, 2024
5d32551
Cleanup
mcourteaux Aug 12, 2024
11c5209
Enum class instead of enum for ApproximationPrecision.
mcourteaux Aug 12, 2024
f3b9d8f
Weird Metal limits. There should be a better way...
mcourteaux Aug 12, 2024
1a92308
Skip test for WebGPU.
mcourteaux Aug 12, 2024
775061a
Fast atan/atan2 polynomials reoptimized. New optimization strategy: ULP.
mcourteaux Aug 13, 2024
6cb4fac
Feedback Steven.
mcourteaux Aug 13, 2024
e9823c1
More comments and test mantissa error.
mcourteaux Aug 14, 2024
3ced523
Do not error when testing arctan performance on Metal / WebGPU.
mcourteaux Aug 14, 2024
b35f7fa
Rework precision specification. Generalize towards using this for oth…
mcourteaux Nov 11, 2024
c004f72
Clang-format.
mcourteaux Nov 11, 2024
34a5ff9
Fix makefile and clang-tidy.
mcourteaux Nov 11, 2024
9408c93
Fix incorrect approximation selection when required precision is not …
mcourteaux Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ SOURCE_FILES = \
AlignLoads.cpp \
AllocationBoundsInference.cpp \
ApplySplit.cpp \
ApproximationTables.cpp \
Argument.cpp \
AssociativeOpsTable.cpp \
Associativity.cpp \
Expand Down
111 changes: 111 additions & 0 deletions src/ApproximationTables.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include "ApproximationTables.h"

namespace Halide {
namespace Internal {

namespace {

using OO = ApproximationPrecision::OptimizationObjective;

// Generate this table with:
// python3 src/polynomial_optimizer.py atan --order 1 2 3 4 5 6 7 8 --loss mse mae mulpe mulpe_mae --no-gui --format table
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may need to be wrapped with // clang-format:off to avoid getting mangled

std::vector<Approximation> table_atan = {
{OO::MSE, 9.249650e-04, 7.078984e-02, 2.411e+06, {+8.56188008e-01}},
{OO::MSE, 1.026356e-05, 9.214909e-03, 3.985e+05, {+9.76213454e-01, -2.00030200e-01}},
{OO::MSE, 1.577588e-07, 1.323851e-03, 6.724e+04, {+9.95982073e-01, -2.92278128e-01, +8.30180680e-02}},
{OO::MSE, 2.849011e-09, 1.992218e-04, 1.142e+04, {+9.99316541e-01, -3.22286501e-01, +1.49032461e-01, -4.08635592e-02}},
{OO::MSE, 5.667504e-11, 3.080100e-05, 1.945e+03, {+9.99883373e-01, -3.30599535e-01, +1.81451316e-01, -8.71733830e-02, +2.18671936e-02}},
{OO::MSE, 1.202662e-12, 4.846916e-06, 3.318e+02, {+9.99980065e-01, -3.32694393e-01, +1.94019697e-01, -1.17694732e-01, +5.40822080e-02, -1.22995279e-02}},
{OO::MSE, 2.672889e-14, 7.722732e-07, 5.664e+01, {+9.99996589e-01, -3.33190090e-01, +1.98232868e-01, -1.32941469e-01, +8.07623712e-02, -3.46124853e-02, +7.15115276e-03}},
{OO::MSE, 6.147315e-16, 1.245768e-07, 9.764e+00, {+9.99999416e-01, -3.33302229e-01, +1.99511173e-01, -1.39332647e-01, +9.70944891e-02, -5.68823386e-02, +2.25679012e-02, -4.25772648e-03}},

{OO::MAE, 1.097847e-03, 4.801638e-02, 2.793e+06, {+8.33414544e-01}},
{OO::MAE, 1.209593e-05, 4.968992e-03, 4.623e+05, {+9.72410454e-01, -1.91981283e-01}},
{OO::MAE, 1.839382e-07, 6.107084e-04, 7.766e+04, {+9.95360080e-01, -2.88702052e-01, +7.93508437e-02}},
{OO::MAE, 3.296902e-09, 8.164167e-05, 1.313e+04, {+9.99214108e-01, -3.21178073e-01, +1.46272006e-01, -3.89915187e-02}},
{OO::MAE, 6.523525e-11, 1.147459e-05, 2.229e+03, {+9.99866373e-01, -3.30305517e-01, +1.80162434e-01, -8.51611537e-02, +2.08475020e-02}},
{OO::MAE, 1.378842e-12, 1.667328e-06, 3.792e+02, {+9.99977226e-01, -3.32622991e-01, +1.93541452e-01, -1.16429278e-01, +5.26504600e-02, -1.17203722e-02}},
{OO::MAE, 3.055131e-14, 2.480947e-07, 6.457e+01, {+9.99996113e-01, -3.33173716e-01, +1.98078484e-01, -1.32334692e-01, +7.96260166e-02, -3.36062649e-02, +6.81247117e-03}},
{OO::MAE, 7.013215e-16, 3.757868e-08, 1.102e+01, {+9.99999336e-01, -3.33298615e-01, +1.99465749e-01, -1.39086791e-01, +9.64233077e-02, -5.59142254e-02, +2.18643190e-02, -4.05495427e-03}},

{OO::MULPE, 1.355602e-03, 1.067325e-01, 1.808e+06, {+8.92130617e-01}},
{OO::MULPE, 2.100588e-05, 1.075508e-02, 1.822e+05, {+9.89111122e-01, -2.14468039e-01}},
{OO::MULPE, 3.573985e-07, 1.316370e-03, 2.227e+04, {+9.98665077e-01, -3.02990987e-01, +9.10404434e-02}},
{OO::MULPE, 6.474958e-09, 1.548508e-04, 2.619e+03, {+9.99842198e-01, -3.26272641e-01, +1.56294460e-01, -4.46207045e-02}},
{OO::MULPE, 1.313474e-10, 2.533532e-05, 4.294e+02, {+9.99974110e-01, -3.31823782e-01, +1.85886095e-01, -9.30024008e-02, +2.43894760e-02}},
{OO::MULPE, 3.007880e-12, 3.530685e-06, 5.983e+01, {+9.99996388e-01, -3.33036463e-01, +1.95959706e-01, -1.22068745e-01, +5.83403647e-02, -1.37966171e-02}},
{OO::MULPE, 6.348880e-14, 4.882649e-07, 8.276e+00, {+9.99999499e-01, -3.33273408e-01, +1.98895454e-01, -1.35153794e-01, +8.43185278e-02, -3.73434598e-02, +7.95583230e-03}},
{OO::MULPE, 1.369569e-15, 7.585036e-08, 1.284e+00, {+9.99999922e-01, -3.33320840e-01, +1.99708563e-01, -1.40257063e-01, +9.93094012e-02, -5.97138046e-02, +2.44056181e-02, -4.73371006e-03}},

{OO::MULPE_MAE, 9.548909e-04, 6.131488e-02, 2.570e+06, {+8.46713042e-01}},
{OO::MULPE_MAE, 1.159917e-05, 6.746680e-03, 3.778e+05, {+9.77449762e-01, -1.98798279e-01}},
{OO::MULPE_MAE, 1.783646e-07, 8.575388e-04, 6.042e+04, {+9.96388826e-01, -2.92591679e-01, +8.24585555e-02}},
{OO::MULPE_MAE, 3.265269e-09, 1.190548e-04, 9.505e+03, {+9.99430906e-01, -3.22774535e-01, +1.49370817e-01, -4.07480795e-02}},
{OO::MULPE_MAE, 6.574962e-11, 1.684690e-05, 1.515e+03, {+9.99909079e-01, -3.30795737e-01, +1.81810037e-01, -8.72860225e-02, +2.17776539e-02}},
{OO::MULPE_MAE, 1.380489e-12, 2.497538e-06, 2.510e+02, {+9.99984893e-01, -3.32748885e-01, +1.94193211e-01, -1.17865932e-01, +5.40633775e-02, -1.22309990e-02}},
{OO::MULPE_MAE, 3.053218e-14, 3.784868e-07, 4.181e+01, {+9.99997480e-01, -3.33205127e-01, +1.98309644e-01, -1.33094430e-01, +8.08643094e-02, -3.45859503e-02, +7.11261604e-03}},
{OO::MULPE_MAE, 7.018877e-16, 5.862915e-08, 6.942e+00, {+9.99999581e-01, -3.33306326e-01, +1.99542180e-01, -1.39433369e-01, +9.72462857e-02, -5.69734398e-02, +2.25639390e-02, -4.24074590e-03}},
};
} // namespace

const Approximation *find_best_approximation(const std::vector<Approximation> &table, ApproximationPrecision precision) {
const Approximation *best = nullptr;
constexpr int term_cost = 20;
constexpr int extra_term_cost = 200;
double best_score = 0;
// std::printf("Looking for min_terms=%d, max_absolute_error=%f\n", precision.constraint_min_poly_terms, precision.constraint_max_absolute_error);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code

for (size_t i = 0; i < table.size(); ++i) {
const Approximation &e = table[i];

double penalty = 0.0;

int obj_score = e.objective == precision.optimized_for ? 100 * term_cost : 0;
if (precision.optimized_for == ApproximationPrecision::MULPE_MAE && e.objective == ApproximationPrecision::MULPE) {
obj_score = 50 * term_cost; // When MULPE_MAE is not available, prefer MULPE.
}

int num_terms = int(e.coefficients.size());
int term_count_score = (12 - num_terms) * term_cost;
if (num_terms < precision.constraint_min_poly_terms) {
penalty += (precision.constraint_min_poly_terms - num_terms) * extra_term_cost;
}

double precision_score = 0;
// If we don't care about the maximum number of terms, we maximize precision.
switch (precision.optimized_for) {
case ApproximationPrecision::MSE:
precision_score = -std::log(e.mse);
break;
case ApproximationPrecision::MAE:
precision_score = -std::log(e.mae);
break;
case ApproximationPrecision::MULPE:
precision_score = -std::log(e.mulpe);
break;
case ApproximationPrecision::MULPE_MAE:
precision_score = -0.5 * std::log(e.mulpe * e.mae);
break;
}

if (precision.constraint_max_absolute_error > 0.0 && precision.constraint_max_absolute_error < e.mae) {
float error_ratio = e.mae / precision.constraint_max_absolute_error;
penalty += 20 * error_ratio * extra_term_cost; // penalty for not getting the required precision.
}

double score = obj_score + term_count_score + precision_score - penalty;
// std::printf("Score for %zu (%zu terms): %f = %d + %d + %f - penalty %f\n", i, e.coefficients.size(), score, obj_score, term_count_score, precision_score, penalty);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code

if (score > best_score || best == nullptr) {
best = &e;
best_score = score;
}
}
// std::printf("Best score: %f\n", best_score);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code

return best;
}

const Approximation *best_atan_approximation(Halide::ApproximationPrecision precision) {
return find_best_approximation(table_atan, precision);
}

} // namespace Internal
} // namespace Halide
21 changes: 21 additions & 0 deletions src/ApproximationTables.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Halide doesn't use #pragma once; instead, wrap in

#ifndef HALIDE_APPROXIMATION_TABLES_H_
#define HALIDE_APPROXIMATION_TABLES_H_
...

#endif


#include <vector>

#include "IROperator.h"

namespace Halide {
namespace Internal {

struct Approximation {
ApproximationPrecision::OptimizationObjective objective;
double mse;
double mae;
double mulpe;
std::vector<double> coefficients;
};

const Approximation *best_atan_approximation(Halide::ApproximationPrecision precision);

} // namespace Internal
} // namespace Halide
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ target_sources(
WrapCalls.h
)

# The sources that go into libHalide. For the sake of IDE support, headers that
# exist in src/ but are not public should be included here.
# The sources that go into libHalide.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you alter the comment?

target_sources(
Halide
PRIVATE
Expand All @@ -232,6 +231,7 @@ target_sources(
AlignLoads.cpp
AllocationBoundsInference.cpp
ApplySplit.cpp
ApproximationTables.cpp
Argument.cpp
AssociativeOpsTable.cpp
Associativity.cpp
Expand Down
68 changes: 67 additions & 1 deletion src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <sstream>
#include <utility>

#include "ApproximationTables.h"
#include "CSE.h"
#include "ConstantBounds.h"
#include "Debug.h"
Expand Down Expand Up @@ -1388,7 +1389,7 @@ Expr fast_sin_cos(const Expr &x_full, bool is_sin) {
Expr sin_usecos = is_sin ? ((k_mod4 == 1) || (k_mod4 == 3)) : ((k_mod4 == 0) || (k_mod4 == 2));
Expr flip_sign = is_sin ? (k_mod4 > 1) : ((k_mod4 == 1) || (k_mod4 == 2));

// Reduce the angle modulo pi/2.
// Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
Expr x = x_full - k_real * pi_over_two;

const float sin_c2 = -0.16666667163372039794921875f;
Expand Down Expand Up @@ -1425,6 +1426,71 @@ Expr fast_cos(const Expr &x_full) {
return fast_sin_cos(x_full, false);
}

// A vectorizable atan and atan2 implementation.
// Based on the ideas presented in https://mazzo.li/posts/vectorized-atan2.html.
Expr fast_atan_approximation(const Expr &x_full, ApproximationPrecision precision, bool between_m1_and_p1) {
const float pi_over_two = 1.57079632679489661923f;
Expr x;
// if x > 1 -> atan(x) = Pi/2 - atan(1/x)
Expr x_gt_1 = abs(x_full) > 1.0f;
if (between_m1_and_p1) {
x = x_full;
} else {
x = select(x_gt_1, 1.0f / x_full, x_full);
}

// Coefficients obtained using src/polynomial_optimizer.py
// Note that the maximal errors are computed with numpy with double precision.
// The real errors are a bit larger with single-precision floats (see correctness/fast_arctan.cpp).
// Also note that ULP distances which are not units are bogus, but this is because this error
// was again measured with double precision, so the actual reconstruction had more bits of precision
// than the actual float32 target value. So in practice the MaxULP Error will be close to round(MaxUlpE).

// The table is huge, so let's put clang-format off and handle the layout manually:
// clang-format off
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a clang-format:off here? I don't see a table, what is the comment referring to?

const Internal::Approximation *approx = Internal::best_atan_approximation(precision);
const std::vector<double> &c = approx->coefficients;

Expr x2 = x * x;
Expr result = float(c.back());
for (size_t i = 1; i < c.size(); ++i) {
result = x2 * result + float(c[c.size() - i - 1]);
}
result *= x;

if (!between_m1_and_p1) {
result = select(x_gt_1, select(x_full < 0, -pi_over_two, pi_over_two) - result, result);
}
return common_subexpression_elimination(result);
}
Expr fast_atan(const Expr &x_full, ApproximationPrecision precision) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Insert blank line

return fast_atan_approximation(x_full, precision, false);
}

Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision precision) {
const float pi = 3.14159265358979323846f;
const float pi_over_two = 1.57079632679489661923f;
// Making sure we take the ratio of the biggest number by the smallest number (in absolute value)
// will always give us a number between -1 and +1, which is the range over which the approximation
// works well. We can therefore also skip the inversion logic in the fast_atan_approximation function
// by passing true for "between_m1_and_p1". This increases both speed (1 division instead of 2) and
// numerical precision.
Expr swap = abs(y) > abs(x);
Expr atan_input = select(swap, x, y) / select(swap, y, x);
Expr ati = fast_atan_approximation(atan_input, precision, true);
Expr at = select(swap, select(atan_input >= 0.0f, pi_over_two, -pi_over_two) - ati, ati);
// This select statement is literally taken over from the definition on Wikipedia.
// There might be optimizations to be done here, but I haven't tried that yet. -- Martijn
Expr result = select(
x > 0.0f, at,
x < 0.0f && y >= 0.0f, at + pi,
x < 0.0f && y < 0.0f, at - pi,
x == 0.0f && y > 0.0f, pi_over_two,
x == 0.0f && y < 0.0f, -pi_over_two,
0.0f);
return common_subexpression_elimination(result);
}

Expr fast_exp(const Expr &x_full) {
user_assert(x_full.type() == Float(32)) << "fast_exp only works for Float(32)";

Expand Down
50 changes: 50 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,56 @@ Expr fast_sin(const Expr &x);
Expr fast_cos(const Expr &x);
// @}

/**
* Struct that allows the user to specify several requirements for functions
* that are approximated by polynomial expansions. These polynomials can be
* optimized for four different metrics: Mean Squared Error, Maximum Absolute Error,
* Maximum Units in Last Place (ULP) Error, or a 50%/50% blend of MAE and MULPE.
*
* Orthogonally to the optimization objective, these polynomials can vary
* in degree. Higher degree polynomials will give more precise results.
* Note that instead of specifying the degree, the number of terms is used instead.
* E.g., even symmetric functions may be implemented using only even powers, for which
* A number of terms of 4 would actually mean that terms in [1, x^2, x^4, x^6] are used,
* which is degree 6.
*
* Additionally, if you don't care about number of terms in the polynomial
* and you do care about the maximal absolute error the approximation may have
* over the domain, you may specify values and the implementation
* will decide the appropriate polynomial degree that achieves this precision.
*/
struct ApproximationPrecision {
enum OptimizationObjective {
MSE, //< Mean Squared Error Optimized.
MAE, //< Optimized for Max Absolute Error.
MULPE, //< Optimized for Max ULP Error. ULP is "Units in Last Place", measured in IEEE 32-bit floats.
MULPE_MAE, //< Optimized for simultaneously Max ULP Error, and Max Absolute Error, each with a weight of 50%.
} optimized_for;
int constraint_min_poly_terms{0}; //< Number of terms in polynomial (zero for no constraint).
float constraint_max_absolute_error{0.0f}; //< Max absolute error (zero for no constraint).
};

/** Fast vectorizable approximations for arctan and arctan2 for Float(32).
*
* Desired precision can be specified as either a maximum absolute error (MAE) or
* the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which
* are optimized for either:
* - MSE (Mean Squared Error)
* - MAE (Maximum Absolute Error)
* - MULPE (Maximum Units in Last Place Error).
*
* The default (Max ULP Error Polynomial of 6 terms) has a MAE of 3.53e-6.
* For more info on the available approximations and their precisions, see the table in ApproximationTables.cpp.
*
* Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial.
* Note: Poly8 is only useful to increase precision for atan, and not for atan2.
* Note: The performance of this functions seem to be not reliably faster on WebGPU (for now, August 2024).
*/
// @{
Expr fast_atan(const Expr &x, ApproximationPrecision precision = {ApproximationPrecision::MULPE, 6});
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = {ApproximationPrecision::MULPE, 6});
// @}

/** Fast approximate cleanly vectorizable log for Float(32). Returns
* nonsense for x <= 0.0f. Accurate up to the last 5 bits of the
* mantissa. Vectorizes cleanly. */
Expand Down
Loading
Loading