Commit f10cc89

configure the dummy train to not parallelize
1 parent 854d6e6 commit f10cc89

1 file changed, +6 -0 lines

tests/test_module.py (+6)
@@ -88,6 +88,8 @@ def test_train(model_name, use_atomref, precision, tmpdir):
 def test_dummy_train(model_name, use_atomref, precision, tmpdir):
     import torch

+    torch.set_num_threads(1)
+
     accelerator = "auto"
     if os.getenv("CPU_TRAIN", "false") == "true":
         # OSX MPS backend runs out of memory on Github Actions
@@ -111,6 +113,7 @@ def test_dummy_train(model_name, use_atomref, precision, tmpdir):
         num_rbf=4,
         batch_size=2,
         precision=precision,
+        num_workers=0,
         **extra_args,
     )
     datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))
@@ -128,6 +131,9 @@ def test_dummy_train(model_name, use_atomref, precision, tmpdir):
         precision=args["precision"],
         inference_mode=False,
         accelerator=accelerator,
+        num_nodes=1,
+        devices=1,
+        use_distributed_sampler=False,
     )
     trainer.fit(module, datamodule)
     trainer.test(module, datamodule)
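
For context, the three changes all serve the same goal of keeping the dummy training test in a single process and a single thread: torch.set_num_threads(1) caps intra-op CPU threading, num_workers=0 keeps data loading in the main process instead of worker subprocesses, and num_nodes=1 / devices=1 / use_distributed_sampler=False pin the Lightning Trainer to one device with the DataLoader's plain sampler. Below is a minimal standalone sketch of the same configuration, not the repository's test code: the RandomDataset and TinyModule names are hypothetical stand-ins, and it assumes PyTorch Lightning >= 2.0, where the use_distributed_sampler flag exists.

    import torch
    from torch.utils.data import DataLoader, Dataset
    import pytorch_lightning as pl


    class RandomDataset(Dataset):
        # Hypothetical stand-in for DummyDataset: a handful of random vectors.
        def __init__(self, size=8, dim=4):
            self.data = torch.randn(size, dim)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]


    class TinyModule(pl.LightningModule):
        # Hypothetical stand-in for the tested module: a single linear layer.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        torch.set_num_threads(1)  # cap intra-op CPU parallelism to one thread

        # num_workers=0 loads batches in the main process (no worker subprocesses)
        loader = DataLoader(RandomDataset(), batch_size=2, num_workers=0)

        trainer = pl.Trainer(
            accelerator="cpu",
            num_nodes=1,  # single machine
            devices=1,  # single device, so Lightning never spawns DDP processes
            use_distributed_sampler=False,  # keep the DataLoader's plain sampler
            max_epochs=1,
            logger=False,
            enable_checkpointing=False,
        )
        trainer.fit(TinyModule(), loader)

With this setup the run is deterministic in its process layout, which is what makes it suitable for constrained CI runners such as the GitHub Actions machines mentioned in the diff's CPU_TRAIN branch.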
