Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 04ebe45

Browse files
larroy authored and anirudh2290 committed
Prevent after-fork number of OMP threads being bigger than 1. (#16999)
* Prevent the after-fork number of OMP threads from being bigger than 1. This could happen if OMP_NUM_THREADS was set in the environment. Because engine::OpenMP::Get()->set_enabled(false) is called in initialize.cc in the child after forking, the behaviour goes back to what it was before #15762 was introduced. Regions using OMP get their thread count from GetRecommendedOMPThreadCount, so if OMP is disabled they get 1 thread and run serially.
* Add C++ unit test.
* Add comment.
1 parent c82af38 commit 04ebe45

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

src/engine/openmp.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,11 @@ void OpenMP::set_reserve_cores(int cores) {
9090

9191
int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
9292
#ifdef _OPENMP
93-
if (omp_num_threads_set_in_environment_) {
94-
return omp_get_max_threads();
95-
}
9693
if (enabled_) {
94+
// OMP_NUM_THREADS was set in the environment at the time of static initialization
95+
if (omp_num_threads_set_in_environment_) {
96+
return omp_get_max_threads();
97+
}
9798
int thread_count = omp_get_max_threads();
9899
if (exclude_reserved) {
99100
if (reserve_cores_ >= thread_count) {
@@ -107,8 +108,9 @@ int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
107108
return thread_count;
108109
}
109110
return omp_thread_max_;
111+
} else {
112+
return 1;
110113
}
111-
return 1;
112114
#else
113115
return 1;
114116
#endif

tests/cpp/engine/omp_test.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <gtest/gtest.h>
21+
22+
#include "../include/test_util.h"
23+
#include "../../src/engine/openmp.h"
24+
25+
#if defined(unix) || defined(__unix__) || defined(__unix)
26+
#include <unistd.h>
27+
#include <sys/types.h>
28+
#include <dmlc/logging.h>
29+
30+
31+
TEST(OMPBehaviour, after_fork) {
  /*
   * Check that after fork, OMP is disabled, and the recommended thread count is 1 to prevent
   * process fanout.
   *
   * The expectations are evaluated in the forked child. The child reports their outcome through
   * its exit status and the parent asserts on that status, so a failing expectation in the child
   * fails this test in the parent as well.
   */
  using namespace mxnet::engine;
  auto openmp = OpenMP::Get();
  pid_t pid = fork();
  if (pid == 0) {
    // Child process.
    EXPECT_FALSE(openmp->enabled());
    EXPECT_EQ(openmp->GetRecommendedOMPThreadCount(), 1);
    // Leave the child right here. Returning from the test body would let the child carry on
    // with the remainder of the test program as a duplicate of the parent. _exit() terminates
    // immediately without running atexit handlers.
    _exit(::testing::Test::HasFailure() ? 1 : 0);
  } else if (pid > 0) {
    // Parent process: wait for the child and propagate its verdict.
    int status = 0;
    int ret = waitpid(pid, &status, 0);
    CHECK_EQ(ret, pid) << "waitpid failed";
    EXPECT_TRUE(WIFEXITED(status)) << "forked child did not exit normally";
    if (WIFEXITED(status)) {
      EXPECT_EQ(WEXITSTATUS(status), 0) << "OMP checks failed in the forked child";
    }
  } else {
    CHECK(false) << "fork failed";
  }
}
50+
#endif

tests/python/unittest/test_engine.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
import nose
1919
import mxnet as mx
20+
import os
21+
import unittest
22+
from mxnet.test_utils import EnvManager
2023

2124
def test_bulk():
2225
with mx.engine.bulk(10):
@@ -30,6 +33,44 @@ def test_bulk():
3033
x += 1
3134
assert (x.asnumpy() == 104).all()
3235

36+
@unittest.skip("OMP platform dependent")
def test_engine_openmp_after_fork():
    """
    Test that the number of max threads in the child is 1. After forking we should not use a bigger
    OMP thread pool.

    With GOMP the child always has the same number when calling omp_get_max_threads, with LLVM OMP
    the child respects the number of max threads set in the parent.

    The check itself runs in the forked child, which reports the outcome through its exit status;
    the parent asserts that the child exited normally with status 0.
    """
    import sys
    import traceback

    with EnvManager('OMP_NUM_THREADS', '42'):
        r, w = os.pipe()
        pid = os.fork()
        if pid:
            # Parent: release the child by writing one byte to the pipe, then wait for it.
            os.close(r)
            wfd = os.fdopen(w, 'w')
            wfd.write('a')
            omp_max_threads = mx.base._LIB.omp_get_max_threads()
            print("Parent omp max threads: {}".format(omp_max_threads))
            try:
                wfd.close()
            except (IOError, OSError):
                # Best effort: the child may already have gone away and closed its end of the
                # pipe. Its exit status, checked below, is what decides the test.
                pass
            (cpid, status) = os.waitpid(pid, 0)
            assert cpid == pid
            # These asserts must propagate; swallowing them would make the test unable to fail.
            assert os.WIFEXITED(status), "child did not exit normally"
            assert os.WEXITSTATUS(status) == 0, "OMP check failed in the forked child"
        else:
            # Child: always leave through os._exit so that it never continues running the
            # remaining tests as a duplicate of the parent process.
            exit_code = 1
            try:
                os.close(w)
                rfd = os.fdopen(r, 'r')
                rfd.read(1)
                omp_max_threads = mx.base._LIB.omp_get_max_threads()
                print("Child omp max threads: {}".format(omp_max_threads))
                assert omp_max_threads == 1
                exit_code = 0
            except BaseException:
                # Make the failure visible in the test output before exiting non-zero.
                traceback.print_exc()
            finally:
                # os._exit does not flush stdio buffers, so do it explicitly.
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(exit_code)
72+
73+
3374

3475
if __name__ == '__main__':
3576
import nose

0 commit comments

Comments
 (0)