Skip to content

Commit

Permalink
Start handling inf/NaN inputs properly
Browse files Browse the repository at this point in the history
  • Loading branch information
Geolm committed Jan 26, 2024
1 parent c82b2aa commit 99455c8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
12 changes: 11 additions & 1 deletion math_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ extern "C" {
// lane-wise (a < b) : the integer comparison mask is handed back reinterpreted as a float vector
static inline simd_vector simd_cmp_lt(simd_vector a, simd_vector b)
{
    const uint32x4_t mask = vcltq_f32(a, b);
    return vreinterpretq_f32_u32(mask);
}
// lane-wise (a <= b) : the integer comparison mask is handed back reinterpreted as a float vector
static inline simd_vector simd_cmp_le(simd_vector a, simd_vector b)
{
    const uint32x4_t mask = vcleq_f32(a, b);
    return vreinterpretq_f32_u32(mask);
}
// lane-wise (a == b) : the integer comparison mask is handed back reinterpreted as a float vector
static inline simd_vector simd_cmp_eq(simd_vector a, simd_vector b)
{
    const uint32x4_t mask = vceqq_f32(a, b);
    return vreinterpretq_f32_u32(mask);
}
// lane-wise (a != b), built as the bitwise complement of the equality mask
// (so any lane that fails the equality test, NaN included, is reported as "not equal")
static inline simd_vector simd_cmp_neq(simd_vector a, simd_vector b)
{
    const uint32x4_t equal = vceqq_f32(a, b);
    const uint32x4_t not_equal = vmvnq_u32(equal);
    return vreinterpretq_f32_u32(not_equal);
}
// mask of the lanes holding a NaN : a lane is flagged when it does not compare equal to itself
static inline simd_vector simd_isnan(simd_vector a)
{
    return simd_cmp_neq(a, a);
}
// bitwise blend : bits set in mask pick the bits of b, cleared bits pick the bits of a
static inline simd_vector simd_select(simd_vector a, simd_vector b, simd_vector mask)
{
    const uint32x4_t selector = vreinterpretq_u32_f32(mask);
    return vbslq_f32(selector, b, a);
}
// broadcasts value to every lane of the vector
static inline simd_vector simd_splat(float value)
{
    return vdupq_n_f32(value);
}
// vector with every lane set to 0.f
static inline simd_vector simd_splat_zero(void)
{
    return vdupq_n_f32(0.f);
}
Expand Down Expand Up @@ -210,6 +212,8 @@ extern "C" {
// lane-wise (a < b) using the ordered, non-signaling predicate (_CMP_LT_OQ)
static inline simd_vector simd_cmp_lt(simd_vector a, simd_vector b)
{
    return _mm256_cmp_ps(a, b, _CMP_LT_OQ);
}
// lane-wise (a <= b) using the ordered, non-signaling predicate (_CMP_LE_OQ)
static inline simd_vector simd_cmp_le(simd_vector a, simd_vector b)
{
    return _mm256_cmp_ps(a, b, _CMP_LE_OQ);
}
// lane-wise (a == b) using the ordered, non-signaling predicate (_CMP_EQ_OQ)
static inline simd_vector simd_cmp_eq(simd_vector a, simd_vector b)
{
    return _mm256_cmp_ps(a, b, _CMP_EQ_OQ);
}
// lane-wise (a != b)
// Uses the unordered, non-signaling predicate (_CMP_NEQ_UQ) so that a lane with a NaN operand
// is reported as "not equal", like the C operator != and like the NEON implementation which
// complements the equality mask. The ordered variant (_CMP_NEQ_OQ) returns false for NaN lanes.
static inline simd_vector simd_cmp_neq(simd_vector a, simd_vector b)
{
    return _mm256_cmp_ps(a, b, _CMP_NEQ_UQ);
}
// mask of the lanes holding a NaN : with the unordered "not equal" predicate, comparing a
// vector against itself only succeeds on NaN lanes
static inline simd_vector simd_isnan(simd_vector a)
{
    return _mm256_cmp_ps(a, a, _CMP_NEQ_UQ);
}
// lane-wise square root
static inline simd_vector simd_sqrt(simd_vector a)
{
    return _mm256_sqrt_ps(a);
}
// negates every lane by xor-ing it with simd_sign_mask()
// (helper defined elsewhere in this header, presumably the IEEE-754 sign bit of each lane)
static inline simd_vector simd_neg(simd_vector a)
{
    const simd_vector sign_bits = simd_sign_mask();
    return _mm256_xor_ps(a, sign_bits);
}
// lane-wise reciprocal (1/a) as computed by _mm256_rcp_ps, which is an approximation
static inline simd_vector simd_rcp(simd_vector a)
{
    return _mm256_rcp_ps(a);
}
Expand Down Expand Up @@ -367,6 +371,10 @@ static inline simd_vector simd_sign(simd_vector a)
{
simd_vector one = simd_splat(1.f);
simd_vector invalid_mask = simd_cmp_le(x, simd_splat_zero());
invalid_mask = simd_or(invalid_mask, simd_isnan(x));
simd_vector input_is_zero = simd_cmp_eq(x, simd_splat_zero());
simd_vector input_is_infinity = simd_cmp_eq(x, simd_splat_positive_infinity());

x = simd_max(x, simd_min_normalized()); // cut off denormalized stuff

simd_vectori emm0 = simd_shift_right_i(simd_cast_from_float(x), 23);
Expand Down Expand Up @@ -406,7 +414,9 @@ static inline simd_vector simd_sign(simd_vector a)
tmp = simd_mul(e, simd_splat(0.693359375f));
x = simd_add(x, y);
x = simd_add(x, tmp);
x = simd_or(x, invalid_mask); // negative arg will be NAN
x = simd_or(x, invalid_mask); // NAN/negative arg will be NAN
x = simd_select(x, simd_splat_negative_infinity(), input_is_zero); // zero arg will be -inf
x = simd_select(x, simd_splat_positive_infinity(), input_is_infinity); // +inf arg will be +inf

return x;
}
Expand Down
58 changes: 58 additions & 0 deletions tests/test.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,39 @@ TEST generic_test(reference_function ref, approximation_function approx, float r
PASS();
}

//----------------------------------------------------------------------------------------------------------------------
// Broadcasts input to every lane, runs function on it and checks that the first lane of the
// output is exactly target (only lane 0 is inspected).
TEST value_expected(float input, float target, approximation_function function)
{
    float result;

#ifdef __MATH__INTRINSICS__AVX__
    result = _mm256_cvtss_f32(function(_mm256_set1_ps(input)));
#else
    result = vgetq_lane_f32(function(vdupq_n_f32(input)), 0);
#endif

    // the assertion is the same on both platforms, only the way lane 0 is extracted differs
    ASSERT_EQ_FMT(target, result, "%f");

    PASS();
}

//----------------------------------------------------------------------------------------------------------------------
// Broadcasts input to every lane, runs function on it and checks that the output is NaN.
// The AVX path inspects the first lane only, the NEON path requires every lane to be NaN.
TEST nan_expected(float input, approximation_function function)
{
#ifdef __MATH__INTRINSICS__AVX__
    const __m256 v_result = function(_mm256_set1_ps(input));
    const float result = _mm256_cvtss_f32(v_result);
    ASSERT(isnan(result));
#else
    const float32x4_t v_result = function(vdupq_n_f32(input));

    // (v <= v) is false only on NaN lanes : a maximum of zero over the mask means no lane passed
    ASSERT(vmaxvq_u32(vcleq_f32(v_result, v_result)) == 0);
#endif

    PASS();
}


//----------------------------------------------------------------------------------------------------------------------
float atan2_angle(float angle) {return atan2f(sinf(angle) * (angle + 1.f), cosf(angle) * (angle + 1.f));}

Expand Down Expand Up @@ -159,6 +192,27 @@ SUITE(exponentiation)
#endif
}

// Checks the behaviour of the log approximation on special inputs :
//   negative or NaN input -> NaN, 1 -> 0, 0 -> -infinity, +infinity -> +infinity
SUITE(infinity_nan_compliant)
{
    // Special values come from the <math.h> macros rather than from constant divisions by zero
    // (1.f / 0.f, 0.f / 0.f) : those are rejected by some compilers (MSVC error C2124) and are
    // only well-defined on implementations that follow IEC 60559.
    const float positive_inf = INFINITY;
    const float negative_inf = -INFINITY;
    const float not_a_number = NAN;

#ifdef __MATH__INTRINSICS__AVX__
    RUN_TESTp(nan_expected, -1.f, mm256_log_ps);
    RUN_TESTp(nan_expected, not_a_number, mm256_log_ps);
    RUN_TESTp(value_expected, 1.f, 0.f, mm256_log_ps);
    RUN_TESTp(value_expected, 0.f, negative_inf, mm256_log_ps);
    RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_log_ps);
#else
    RUN_TESTp(nan_expected, -1.f, vlogq_f32);
    RUN_TESTp(nan_expected, not_a_number, vlogq_f32);
    RUN_TESTp(value_expected, 1.f, 0.f, vlogq_f32);
    RUN_TESTp(value_expected, 0.f, negative_inf, vlogq_f32);
    RUN_TESTp(value_expected, positive_inf, positive_inf, vlogq_f32);
#endif
}

GREATEST_MAIN_DEFS();

int main(int argc, char * argv[])
Expand All @@ -167,7 +221,11 @@ int main(int argc, char * argv[])

RUN_SUITE(trigonometry);
RUN_SUITE(exponentiation);
RUN_SUITE(infinity_nan_compliant);

GREATEST_MAIN_END();

(void)nan_expected;
(void)value_expected;
}

0 comments on commit 99455c8

Please sign in to comment.