Skip to content

Commit

Permalink
Fix portability/reproducibility of erfinv
Browse files Browse the repository at this point in the history
The tabulated normal distribution used to generate
velocities in grompp relies on erfinv(), which
we previously implemented with Newton-Raphson
iterations. This worked well on x86, which uses
denormal fp numbers, but it loses accuracy on Arm.
This replaces the erfinv functions with more
accurate implementations that at least appear to
generate identical results on different platforms,
which means the generated TPR files should also
be identical. It also tightens the tolerances
for all math functions so we are more likely to
detect similar issues in the future, and adjusts
other reference values that appear to have been
produced at lower accuracy.

Fixes #4824.
  • Loading branch information
Erik Lindahl authored and mabraham committed Aug 14, 2023
1 parent f05373d commit 91e5d1a
Show file tree
Hide file tree
Showing 10 changed files with 529 additions and 300 deletions.
1 change: 1 addition & 0 deletions COPYING
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ Erik Lindahl, 2008-10-07.

Files: src/gromacs/utility/current_function.h
src/gromacs/utility/path.cpp
src/gromacs/math/functions.cpp (erfinv algorithm & constants)

Boost Software License - Version 1.0 - August 17th, 2003

Expand Down
303 changes: 253 additions & 50 deletions src/gromacs/math/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,93 +163,296 @@ std::int64_t greatestCommonDivisor(std::int64_t p, std::int64_t q)
return p;
}

// These inverse error function implementations are simplified versions
// of the algorithms and polynomials developed for Boost by
// John Maddock in 2006, and modified by Jeremy William Murphy in 2015.
//
// You might prefer the original version to avoid any bugs we have
// introduced, but in the spirit of not unnecessarily restricting a
// more liberal license, you are most welcome to also use and redistribute
// the Gromacs versions of these functions under the
// Boost Software License, Version 1.0. (http://www.boost.org/LICENSE_1_0.txt)

// Inverse error function in double precision.
//
// Special cases handled explicitly:
//   |arg| > 1   -> NaN
//   arg == +1   -> +infinity
//   arg == -1   -> -infinity
//   arg == 0    -> 0.0
//
// For all other arguments the result is computed for |arg| using one of
// four piecewise rational approximations (each evaluated with Horner's
// scheme), and the sign of arg is transferred to the result at the end.
// No iterative refinement is applied, so the result depends only on
// basic arithmetic plus std::sqrt and std::log.
double erfinv(double arg)
{
    double x = std::abs(arg);

    if (x > 1.0)
    {
        return std::nan("");
    }
    else if (arg == 1.0)
    {
        return std::numeric_limits<double>::infinity();
    }
    else if (arg == -1.0)
    {
        return -std::numeric_limits<double>::infinity();
    }
    else if (x == 0.0)
    {
        return 0.0;
    }

    // y is the distance from |arg| to 1; the outer branches are expressed in y
    double y = 1.0 - x;
    double p, q, res;

    if (x <= 0.5)
    {
        // Rational approximation for |x| in range [0,0.5]
        p = -0.00538772965071242932965;
        p = p * x + 0.00822687874676915743155;
        p = p * x + 0.0219878681111168899165;
        p = p * x - 0.0365637971411762664006;
        p = p * x - 0.0126926147662974029034;
        p = p * x + 0.0334806625409744615033;
        p = p * x - 0.00836874819741736770379;
        p = p * x - 0.000508781949658280665617;

        q = 0.000886216390456424707504;
        q = q * x - 0.00233393759374190016776;
        q = q * x + 0.0795283687341571680018;
        q = q * x - 0.0527396382340099713954;
        q = q * x - 0.71228902341542847553;
        q = q * x + 0.662328840472002992063;
        q = q * x + 1.56221558398423026363;
        q = q * x - 1.56574558234175846809;
        q = q * x - 0.970005043303290640362;
        q = q * x + 1.0;

        double t = x * (x + 10.0);
        res      = t * 0.0891314744949340820313 + t * p / q;
    }
    else if (x <= 0.75)
    {
        // Rational approx for |x| in range ]0.5,0.75]
        double z = y - 0.25;
        p        = -3.67192254707729348546;
        p        = p * z + 21.1294655448340526258;
        p        = p * z + 17.445385985570866523;
        p        = p * z - 44.6382324441786960818;
        p        = p * z - 18.8510648058714251895;
        p        = p * z + 17.6447298408374015486;
        p        = p * z + 8.37050328343119927838;
        p        = p * z + 0.105264680699391713268;
        p        = p * z - 0.202433508355938759655;

        q = 1.72114765761200282724;
        q = q * z - 22.6436933413139721736;
        q = q * z + 10.8268667355460159008;
        q = q * z + 48.5609213108739935468;
        q = q * z - 20.1432634680485188801;
        q = q * z - 28.6608180499800029974;
        q = q * z + 3.9713437953343869095;
        q = q * z + 6.24264124854247537712;
        q = q * z + 1.0;

        double t = std::sqrt(-2.0 * std::log(y));
        res      = t / (2.249481201171875 + p / q);
    }
    else
    {
        // Branch for 0.75 < x < 1 (meaning 0 < y <= 0.25)
        double t = std::sqrt(-std::log(y));

        if (t < 3.0)
        {
            // |x| in range ]0.75,0.99987659019591335]
            double z = t - 1.125;
            p        = -0.681149956853776992068e-9;
            p        = p * z + 0.285225331782217055858e-7;
            p        = p * z - 0.679465575181126350155e-6;
            p        = p * z + 0.00214558995388805277169;
            p        = p * z + 0.0290157910005329060432;
            p        = p * z + 0.142869534408157156766;
            p        = p * z + 0.337785538912035898924;
            p        = p * z + 0.387079738972604337464;
            p        = p * z + 0.117030156341995252019;
            p        = p * z - 0.163794047193317060787;
            p        = p * z - 0.131102781679951906451;

            q = 0.01105924229346489121;
            q = q * z + 0.152264338295331783612;
            q = q * z + 0.848854343457902036425;
            q = q * z + 2.59301921623620271374;
            q = q * z + 4.77846592945843778382;
            q = q * z + 5.38168345707006855425;
            q = q * z + 3.46625407242567245975;
            q = q * z + 1.0;

            res = t * 0.807220458984375 + t * p / q;
        }
        else
        {
            // |x| in range ]0.99987659019591335,1[
            double z = t - 3.0;
            p        = 0.266339227425782031962e-11;
            p        = p * z - 0.230404776911882601748e-9;
            p        = p * z + 0.460469890584317994083e-5;
            p        = p * z + 0.000157544617424960554631;
            p        = p * z + 0.00187123492819559223345;
            p        = p * z + 0.00950804701325919603619;
            p        = p * z + 0.0185573306514231072324;
            p        = p * z - 0.00222426529213447927281;
            p        = p * z - 0.0350353787183177984712;

            q = 0.764675292302794483503e-4;
            q = q * z + 0.00263861676657015992959;
            q = q * z + 0.0341589143670947727934;
            q = q * z + 0.220091105764131249824;
            q = q * z + 0.762059164553623404043;
            q = q * z + 1.3653349817554063097;
            q = q * z + 1.0;

            res = t * 0.93995571136474609375 + t * p / q;
        }
    }

    // The approximation was evaluated for |arg|; restore the sign of arg
    return std::copysign(res, arg);
}

// Inverse error function in single precision.
//
// Special cases handled explicitly:
//   |arg| > 1   -> NaN
//   arg == +1   -> +infinity
//   arg == -1   -> -infinity
//   arg == 0    -> 0.0F
//
// For all other arguments the result is computed for |arg| using one of
// four piecewise rational approximations (each evaluated with Horner's
// scheme), and the sign of arg is transferred to the result at the end.
// All arithmetic is carried out in float; no iterative refinement is applied.
float erfinv(float arg)
{
    float x = std::abs(arg);

    if (x > 1.0F)
    {
        return std::nan("");
    }
    else if (arg == 1.0F)
    {
        return std::numeric_limits<float>::infinity();
    }
    else if (arg == -1.0F)
    {
        return -std::numeric_limits<float>::infinity();
    }
    else if (x == 0.0F)
    {
        return 0.0F;
    }

    // y is the distance from |arg| to 1; the outer branches are expressed in y
    float y = 1.0F - x;
    float p, q, res;

    // It is likely possible to use polynomials of slightly
    // lower order in single precision, but to really
    // optimize it would also require changing the intervals,
    // and adapting the factors to exact fp32 representation in
    // IEEE-754. Given that we don't use erfinv() in any
    // tight loops it's not needed for now, so we leave it
    // as an exercise to the developer reading this note.
    if (x <= 0.5F)
    {
        // Rational approximation for |x| in range [0,0.5]
        p = -0.00538772965071242932965F;
        p = p * x + 0.00822687874676915743155F;
        p = p * x + 0.0219878681111168899165F;
        p = p * x - 0.0365637971411762664006F;
        p = p * x - 0.0126926147662974029034F;
        p = p * x + 0.0334806625409744615033F;
        p = p * x - 0.00836874819741736770379F;
        p = p * x - 0.000508781949658280665617F;

        q = 0.000886216390456424707504F;
        q = q * x - 0.00233393759374190016776F;
        q = q * x + 0.0795283687341571680018F;
        q = q * x - 0.0527396382340099713954F;
        q = q * x - 0.71228902341542847553F;
        q = q * x + 0.662328840472002992063F;
        q = q * x + 1.56221558398423026363F;
        q = q * x - 1.56574558234175846809F;
        q = q * x - 0.970005043303290640362F;
        q = q * x + 1.0F;

        float t = x * (x + 10.0F);
        res     = t * 0.0891314744949340820313F + t * p / q;
    }
    else if (x <= 0.75F)
    {
        // Rational approx for |x| in range ]0.5,0.75]
        float z = y - 0.25F;
        p       = -3.67192254707729348546F;
        p       = p * z + 21.1294655448340526258F;
        p       = p * z + 17.445385985570866523F;
        p       = p * z - 44.6382324441786960818F;
        p       = p * z - 18.8510648058714251895F;
        p       = p * z + 17.6447298408374015486F;
        p       = p * z + 8.37050328343119927838F;
        p       = p * z + 0.105264680699391713268F;
        p       = p * z - 0.202433508355938759655F;

        q = 1.72114765761200282724F;
        q = q * z - 22.6436933413139721736F;
        q = q * z + 10.8268667355460159008F;
        q = q * z + 48.5609213108739935468F;
        q = q * z - 20.1432634680485188801F;
        q = q * z - 28.6608180499800029974F;
        q = q * z + 3.9713437953343869095F;
        q = q * z + 6.24264124854247537712F;
        q = q * z + 1.0F;

        float t = std::sqrt(-2.0F * std::log(y));
        res     = t / (2.249481201171875F + p / q);
    }
    else
    {
        // Branch for 0.75 < x < 1 (meaning 0 < y <= 0.25)
        float t = std::sqrt(-std::log(y));

        if (t < 3.0F)
        {
            // |x| in range ]0.75,0.99987659019591335]
            float z = t - 1.125F;
            p       = -0.681149956853776992068e-9F;
            p       = p * z + 0.285225331782217055858e-7F;
            p       = p * z - 0.679465575181126350155e-6F;
            p       = p * z + 0.00214558995388805277169F;
            p       = p * z + 0.0290157910005329060432F;
            p       = p * z + 0.142869534408157156766F;
            p       = p * z + 0.337785538912035898924F;
            p       = p * z + 0.387079738972604337464F;
            p       = p * z + 0.117030156341995252019F;
            p       = p * z - 0.163794047193317060787F;
            p       = p * z - 0.131102781679951906451F;

            q = 0.01105924229346489121F;
            q = q * z + 0.152264338295331783612F;
            q = q * z + 0.848854343457902036425F;
            q = q * z + 2.59301921623620271374F;
            q = q * z + 4.77846592945843778382F;
            q = q * z + 5.38168345707006855425F;
            q = q * z + 3.46625407242567245975F;
            q = q * z + 1.0F;

            res = t * 0.807220458984375F + t * p / q;
        }
        else
        {
            // |x| in range ]0.99987659019591335,1[
            float z = t - 3.0F;
            p       = 0.266339227425782031962e-11F;
            p       = p * z - 0.230404776911882601748e-9F;
            p       = p * z + 0.460469890584317994083e-5F;
            p       = p * z + 0.000157544617424960554631F;
            p       = p * z + 0.00187123492819559223345F;
            p       = p * z + 0.00950804701325919603619F;
            p       = p * z + 0.0185573306514231072324F;
            p       = p * z - 0.00222426529213447927281F;
            p       = p * z - 0.0350353787183177984712F;

            q = 0.764675292302794483503e-4F;
            q = q * z + 0.00263861676657015992959F;
            q = q * z + 0.0341589143670947727934F;
            q = q * z + 0.220091105764131249824F;
            q = q * z + 0.762059164553623404043F;
            q = q * z + 1.3653349817554063097F;
            q = q * z + 1.0F;

            res = t * 0.93995571136474609375F + t * p / q;
        }
    }

    // The approximation was evaluated for |arg|; restore the sign of arg
    return std::copysign(res, arg);
}


} // namespace gmx
Loading

0 comments on commit 91e5d1a

Please sign in to comment.