Skip to content

[ENH] Control number of threads with threadpoolctl #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 25, 2020
Merged
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ numpy>=1.15
pandas
scikit-learn>=0.22
scipy>=1.3.3
threadpoolctl
3 changes: 2 additions & 1 deletion tedana/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
'nibabel>=2.1.0',
'scipy',
'pandas',
'matplotlib'
'matplotlib',
'threadpoolctl'
]

TESTS_REQUIRES = [
Expand Down
15 changes: 14 additions & 1 deletion tedana/workflows/t2smap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import argparse
import numpy as np
from scipy import stats
from threadpoolctl import threadpool_limits

from tedana import (combine, decay, io, utils)
from tedana.workflows.parser_utils import is_valid_file
Expand Down Expand Up @@ -89,6 +90,14 @@ def _get_parser():
'demanding monoexponential model is fit'
'to the raw data',
default='loglin')
optional.add_argument('--n-threads',
dest='n_threads',
type=int,
action='store',
help=('Number of threads to use. Used by '
'threadcountctl to set the parameter outside '
'of the workflow function.'),
default=-1)
optional.add_argument('--debug',
dest='debug',
help=argparse.SUPPRESS,
Expand Down Expand Up @@ -248,7 +257,11 @@ def _main(argv=None):
else:
logging.basicConfig(level=logging.INFO)

t2smap_workflow(**vars(options))
kwargs = vars(options)
n_threads = kwargs.pop('n_threads')
n_threads = None if n_threads == -1 else n_threads
with threadpool_limits(limits=n_threads, user_api=None):
t2smap_workflow(**vars(options))


if __name__ == '__main__':
Expand Down
26 changes: 16 additions & 10 deletions tedana/workflows/tedana.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,17 @@
Run the "canonical" TE-Dependent ANAlysis workflow.
"""
import os

os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import os.path as op
import shutil
import logging
import os.path as op
from glob import glob
import datetime
from glob import glob

import argparse
import numpy as np
import pandas as pd
from scipy import stats
from threadpoolctl import threadpool_limits
from nilearn.masking import compute_epi_mask

from tedana import (decay, combine, decomposition, io, metrics, selection,
Expand Down Expand Up @@ -177,6 +171,14 @@ def _get_parser():
'use of IncrementalPCA. May increase workflow '
'duration.'),
default=False)
optional.add_argument('--n-threads',
dest='n_threads',
type=int,
action='store',
help=('Number of threads to use. Used by '
'threadcountctl to set the parameter outside '
'of the workflow function.'),
default=-1)
optional.add_argument('--debug',
dest='debug',
action='store_true',
Expand Down Expand Up @@ -669,7 +671,11 @@ def tedana_workflow(data, tes, out_dir='.', mask=None,
def _main(argv=None):
"""Tedana entry point"""
options = _get_parser().parse_args(argv)
tedana_workflow(**vars(options))
kwargs = vars(options)
n_threads = kwargs.pop('n_threads')
n_threads = None if n_threads == -1 else n_threads
with threadpool_limits(limits=n_threads, user_api=None):
tedana_workflow(**kwargs)


if __name__ == '__main__':
Expand Down