Skip to content

Commit 80a82b6

Browse files
committed
Prevent after-fork number of OMP threads being bigger than 1.
This could happen if it was set in the environment. As we are setting engine::OpenMP::Get()->set_enabled(false) in initialize.cc in the child after forking, the behaviour goes back to what it was before apache#15762 was introduced. Regions using omp get the threads count from GetRecommendedOMPThreadCount, so if omp is disabled they will get 1 thread and run serially
1 parent 7fc6255 commit 80a82b6

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

src/engine/openmp.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ void OpenMP::set_reserve_cores(int cores) {
8383

8484
int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
8585
#ifdef _OPENMP
86-
if (omp_num_threads_set_in_environment_) {
87-
return omp_get_max_threads();
88-
}
8986
if (enabled_) {
87+
if (omp_num_threads_set_in_environment_) {
88+
return omp_get_max_threads();
89+
}
9090
int thread_count = omp_get_max_threads();
9191
if (exclude_reserved) {
9292
if (reserve_cores_ >= thread_count) {
@@ -100,8 +100,9 @@ int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
100100
return thread_count;
101101
}
102102
return omp_thread_max_;
103+
} else {
104+
return 1;
103105
}
104-
return 1;
105106
#else
106107
return 1;
107108
#endif

tests/python/unittest/test_engine.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
import nose
1919
import mxnet as mx
20+
import os
21+
import unittest
22+
from mxnet.test_utils import EnvManager
2023

2124
def test_bulk():
2225
with mx.engine.bulk(10):
@@ -30,6 +33,44 @@ def test_bulk():
3033
x += 1
3134
assert (x.asnumpy() == 104).all()
3235

36+
@unittest.skip("OMP platform dependent")
37+
def test_engine_openmp_after_fork():
38+
"""
39+
Test that the number of max threads in the child is 1. After forking we should not use a bigger
40+
OMP thread pool.
41+
42+
With GOMP the child always has the same number when calling omp_get_max_threads, with LLVM OMP
43+
the child respects the number of max threads set in the parent.
44+
"""
45+
with EnvManager('OMP_NUM_THREADS', '42'):
46+
r, w = os.pipe()
47+
pid = os.fork()
48+
if pid:
49+
os.close(r)
50+
wfd = os.fdopen(w, 'w')
51+
wfd.write('a')
52+
omp_max_threads = mx.base._LIB.omp_get_max_threads()
53+
print("Parent omp max threads: {}".format(omp_max_threads))
54+
try:
55+
wfd.close()
56+
except:
57+
pass
58+
try:
59+
(cpid, status) = os.waitpid(pid, 0)
60+
assert cpid == pid
61+
exit_status = status >> 8
62+
assert exit_status == 0
63+
except:
64+
pass
65+
else:
66+
os.close(w)
67+
rfd = os.fdopen(r, 'r')
68+
rfd.read(1)
69+
omp_max_threads = mx.base._LIB.omp_get_max_threads()
70+
print("Child omp max threads: {}".format(omp_max_threads))
71+
assert omp_max_threads == 1
72+
73+
3374

3475
if __name__ == '__main__':
3576
import nose

0 commit comments

Comments
 (0)