From 944770381df5f53445de49d4a9398189d07425a8 Mon Sep 17 00:00:00 2001 From: Geolm Date: Thu, 25 Jan 2024 21:23:07 -0500 Subject: [PATCH] log2 compliance --- math_intrinsics.h | 14 +++++++------- tests/test.c | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/math_intrinsics.h b/math_intrinsics.h index af2c121..d90a328 100644 --- a/math_intrinsics.h +++ b/math_intrinsics.h @@ -429,6 +429,10 @@ static inline simd_vector simd_sign(simd_vector a) __m256 mm256_log2_ps(__m256 x) #endif { + simd_vector invalid_mask = simd_cmp_le(x, simd_splat_zero()); + invalid_mask = simd_or(invalid_mask, simd_isnan(x)); + simd_vector input_is_zero = simd_cmp_eq(x, simd_splat_zero()); + simd_vector input_is_infinity = simd_cmp_eq(x, simd_splat_positive_infinity()); simd_vector one = simd_splat(1.f); simd_vectori exp = simd_splat_i(0x7f800000); simd_vectori mant = simd_splat_i(0x007fffff); @@ -443,13 +447,9 @@ static inline simd_vector simd_sign(simd_vector a) p = simd_mul(p, simd_sub(m, one)); simd_vector result = simd_add(p, e); - // we can't compute a logarithm beyond this value, so we'll mark it as -infinity to indicate close to 0 - simd_vector ltminus127 = simd_cmp_le(result, simd_splat(-127.f)); - result = simd_select(result, simd_splat_negative_infinity(), ltminus127); - - // Check for negative values and return NaN - simd_vector lt0 = simd_cmp_lt(x, simd_splat_zero()); - result = simd_select(result, simd_splat_nan(), lt0); + result = simd_or(result, invalid_mask); // NAN/negative arg will be NAN + result = simd_select(result, simd_splat_negative_infinity(), input_is_zero); // zero arg will be -inf + result = simd_select(result, simd_splat_positive_infinity(), input_is_infinity); // +inf arg will be +inf return result; } diff --git a/tests/test.c b/tests/test.c index afee20e..7dcee33 100644 --- a/tests/test.c +++ b/tests/test.c @@ -113,8 +113,8 @@ TEST nan_expected(float input, approximation_function function) ASSERT(isnan(result)); #else float32x4_t v_input = vdupq_n_f32(input); - float32x4_t v_result = function(v_input); - ASSERT(vmaxvq_u32(vcleq_f32(v_result, v_result)) == 0); + float result = vgetq_lane_f32(function(v_input), 0); + ASSERT(isnan(result)); #endif PASS(); @@ -204,12 +204,24 @@ SUITE(infinity_nan_compliance) RUN_TESTp(value_expected, 1.f, 0.f, mm256_log_ps); RUN_TESTp(value_expected, 0.f, negative_inf, mm256_log_ps); RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_log_ps); + + RUN_TESTp(nan_expected, -1.f, mm256_log2_ps); + RUN_TESTp(nan_expected, not_a_number, mm256_log2_ps); + RUN_TESTp(value_expected, 1.f, 0.f, mm256_log2_ps); + RUN_TESTp(value_expected, 0.f, negative_inf, mm256_log2_ps); + RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_log2_ps); #else RUN_TESTp(nan_expected, -1.f, vlogq_f32); RUN_TESTp(nan_expected, not_a_number, vlogq_f32); RUN_TESTp(value_expected, 1.f, 0.f, vlogq_f32); RUN_TESTp(value_expected, 0.f, negative_inf, vlogq_f32); RUN_TESTp(value_expected, positive_inf, positive_inf, vlogq_f32); + + RUN_TESTp(nan_expected, -1.f, vlog2q_f32); + RUN_TESTp(nan_expected, not_a_number, vlog2q_f32); + RUN_TESTp(value_expected, 1.f, 0.f, vlog2q_f32); + RUN_TESTp(value_expected, 0.f, negative_inf, vlog2q_f32); + RUN_TESTp(value_expected, positive_inf, positive_inf, vlog2q_f32); #endif }