Skip to content

Commit 283dff4

Browse files
authored
[compiler-rt][nsan] Add support for nan detection (#101531)
Add support for nan detection. #100305
1 parent da6f423 commit 283dff4

File tree

6 files changed

+166
-0
lines changed

6 files changed

+166
-0
lines changed

compiler-rt/lib/nsan/nsan.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,32 @@ int32_t checkFT(const FT value, ShadowFT Shadow, CheckTypeT CheckType,
445445
const InternalFT check_value = value;
446446
const InternalFT check_shadow = Shadow;
447447

448+
// We only check for NaNs in the value, not the shadow.
449+
if (flags().check_nan && isnan(check_value)) {
450+
GET_CALLER_PC_BP;
451+
BufferedStackTrace stack;
452+
stack.Unwind(pc, bp, nullptr, false);
453+
if (GetSuppressionForStack(&stack, CheckKind::Consistency)) {
454+
// FIXME: optionally print.
455+
return flags().resume_after_suppression ? kResumeFromValue
456+
: kContinueWithShadow;
457+
}
458+
Decorator D;
459+
Printf("%s", D.Warning());
460+
Printf("WARNING: NumericalStabilitySanitizer: NaN detected\n");
461+
Printf("%s", D.Default());
462+
stack.Print();
463+
if (flags().halt_on_error) {
464+
if (common_flags()->abort_on_error)
465+
Printf("ABORTING\n");
466+
else
467+
Printf("Exiting\n");
468+
Die();
469+
}
470+
// Performing other tests for NaN values is meaningless when dealing with numbers.
471+
return kResumeFromValue;
472+
}
473+
448474
// See this article for an interesting discussion of how to compare floats:
449475
// https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
450476
static constexpr const FT Eps = FTInfo<FT>::kEpsilon;

compiler-rt/lib/nsan/nsan_flags.inc

+2
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,5 @@ NSAN_FLAG(bool, enable_loadtracking_stats, false,
4848
"due to invalid or unknown types.")
4949
NSAN_FLAG(bool, poison_in_free, true, "")
5050
NSAN_FLAG(bool, print_stats_on_exit, false, "If true, print stats on exit.")
51+
NSAN_FLAG(bool, check_nan, false,
52+
"If true, check the floating-point number is nan")

compiler-rt/test/nsan/nan.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clangxx_nsan -O0 -g %s -o %t
2+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0 %run %t 2>&1 | FileCheck %s
3+
4+
// RUN: %clangxx_nsan -O3 -g %s -o %t
5+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0 %run %t 2>&1 | FileCheck %s
6+
7+
// RUN: %clangxx_nsan -O0 -g %s -o %t
8+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=1 not %run %t
9+
10+
#include <cmath>
11+
#include <cstdio>
12+
13+
// This function returns a NaN value for triggering the NaN detection.
14+
__attribute__((noinline)) float ReturnNaN(float p, float q) {
15+
float ret = p / q;
16+
return ret;
17+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
18+
}
19+
20+
int main() {
21+
float val = ReturnNaN(0., 0.);
22+
printf("%f\n", val);
23+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
24+
return 0;
25+
}

compiler-rt/test/nsan/softmax.cpp

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %clangxx_nsan -O0 -g -DSOFTMAX=softmax %s -o %t
2+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0,log2_max_relative_error=19 %run %t 2>&1 | FileCheck %s
3+
4+
// RUN: %clangxx_nsan -O3 -g -DSOFTMAX=softmax %s -o %t
5+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0,log2_max_relative_error=19 %run %t 2>&1 | FileCheck %s
6+
7+
// RUN: %clangxx_nsan -O0 -g -DSOFTMAX=stable_softmax %s -o %t
8+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=1,log2_max_relative_error=19 %run %t
9+
10+
// RUN: %clangxx_nsan -O3 -g -DSOFTMAX=stable_softmax %s -o %t
11+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=1,log2_max_relative_error=19 %run %t
12+
13+
#include<iostream>
14+
#include<vector>
15+
#include<algorithm>
16+
#include<cmath>
17+
18+
// unstable softmax
19+
template <typename T>
20+
__attribute__((noinline)) void softmax(std::vector<T> &values) {
21+
T sum_exp = 0.0;
22+
for (auto &i: values) {
23+
i = std::exp(i);
24+
sum_exp += i;
25+
}
26+
for (auto &i: values) {
27+
i /= sum_exp;
28+
}
29+
}
30+
31+
// use max value to avoid overflow
32+
// \sigma_i exp(x_i) / \sum_j exp(x_j) = \sigma_i exp(x_i - max(x)) / \sum_j exp(x_j - max(x))
33+
template <typename T>
34+
__attribute__((noinline)) void stable_softmax(std::vector<T> &values) {
35+
T sum_exp = 0.0;
36+
T max_values = *std::max_element(values.begin(), values.end());
37+
for (auto &i: values) {
38+
i = std::exp(i - max_values);
39+
sum_exp += i;
40+
}
41+
for (auto &i:values) {
42+
i /= sum_exp;
43+
}
44+
}
45+
46+
int main() {
47+
std::vector<double> data = {1000, 1001, 1002};
48+
SOFTMAX(data);
49+
for (auto i: data) {
50+
printf("%f", i);
51+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
52+
}
53+
return 0;
54+
}

compiler-rt/test/nsan/vec_sqrt.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clangxx_nsan -O0 -g -mavx %s -o %t
2+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0 %run %t 2>&1 | FileCheck %s
3+
// RUN: %clangxx_nsan -O3 -g -mavx %s -o %t
4+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0 %run %t 2>&1 | FileCheck %s
5+
6+
#include <cmath>
7+
#include <immintrin.h>
8+
#include <iostream>
9+
10+
void simd_sqrt(const float *input, float *output, size_t size) {
11+
size_t i = 0;
12+
for (; i + 7 < size; i += 8) {
13+
__m256 vec = _mm256_loadu_ps(&input[i]);
14+
__m256 result = _mm256_sqrt_ps(vec);
15+
_mm256_storeu_ps(&output[i], result);
16+
}
17+
for (; i < size; ++i) {
18+
output[i] = std::sqrt(input[i]);
19+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
20+
}
21+
}
22+
23+
int main() {
24+
float input[] = {1.0, 2.0, -3.0, 4.0, 5.0, 6.0, 7.0,
25+
8.0, 9.0, -10.0, 11.0, 12.0, 13.0, 14.0,
26+
15.0, -16.0, 17.0, -18.0, -19.0, -20.0};
27+
float output[20];
28+
simd_sqrt(input, output, 20);
29+
for (int i = 0; i < 20; ++i) {
30+
std::cout << output[i] << std::endl;
31+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
32+
}
33+
return 0;
34+
}
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clangxx_nsan -O0 -g -mavx %s -o %t
2+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0 %run %t 2>&1 | FileCheck %s
3+
// RUN: %clangxx_nsan -O3 -g -mavx %s -o %t
4+
// RUN: NSAN_OPTIONS=check_nan=true,halt_on_error=0 %run %t 2>&1 | FileCheck %s
5+
#include <iostream>
6+
#include <cmath>
7+
8+
typedef float v8sf __attribute__ ((vector_size(32)));
9+
10+
v8sf simd_sqrt(v8sf a) {
11+
return __builtin_elementwise_sqrt(a);
12+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
13+
}
14+
15+
int main() {
16+
v8sf a = {-1.0, -2.0, -3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
17+
a = simd_sqrt(a);
18+
19+
// This prevents DCE.
20+
for (size_t i = 0; i < 8; ++i) {
21+
std::cout << a[i] << std::endl;
22+
// CHECK: WARNING: NumericalStabilitySanitizer: NaN detected
23+
}
24+
return 0;
25+
}

0 commit comments

Comments
 (0)