
Commit 321cf91

Clarify data_parallel implementation. (#2488)
- It appears we're unnecessarily replicating the model to the source device, but this is not the case.
- Silence some deprecation warnings that were missed previously.
1 parent 540ebac commit 321cf91

2 files changed: +6 −0 lines changed

allennlp/training/util.py (+3 lines)

@@ -237,7 +237,10 @@ def data_parallel(batch_group: List[TensorDict],
              for batch, device in zip(batch_group, cuda_devices)]
 
     used_device_ids = cuda_devices[:len(moved)]
+    # Counterintuitively, it appears replicate expects the source device id to be the first element
+    # in the device id list. See torch.cuda.comm.broadcast_coalesced, which is called indirectly.
     replicas = replicate(model, used_device_ids)
+
     # We pass all our arguments as kwargs. Create a list of empty tuples of the
     # correct shape to serve as (non-existent) positional arguments.
     inputs = [()] * len(batch_group)
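For context, the pattern the commented code follows is the standard torch.nn.parallel sequence: move each batch to its device, replicate the model across the used devices (source device first), run the replicas with parallel_apply, and gather the losses back onto the source device. Below is a minimal sketch of that pattern, not AllenNLP's exact code: the function name data_parallel_sketch, the dict-based batch moving, and the assumption that the model's forward returns a dict with a 'loss' tensor are simplifications for illustration.

from typing import Dict, List

import torch
from torch.nn.parallel import replicate, parallel_apply, gather


def data_parallel_sketch(batch_group: List[Dict[str, torch.Tensor]],
                         model: torch.nn.Module,
                         cuda_devices: List[int]) -> torch.Tensor:
    # Move each batch onto its target GPU (the real code uses AllenNLP's
    # nn_util.move_to_device, which also handles nested structures).
    moved = [{key: tensor.to(device) for key, tensor in batch.items()}
             for batch, device in zip(batch_group, cuda_devices)]

    used_device_ids = cuda_devices[:len(moved)]
    # replicate() expects the source device id to come first in the device list;
    # it broadcasts the parameters from that device to the other replicas.
    replicas = replicate(model, used_device_ids)

    # All arguments are passed as kwargs, so the positional inputs are empty tuples.
    inputs = [()] * len(moved)
    outputs = parallel_apply(replicas, inputs, moved, used_device_ids)

    # Assumes each output dict contains a 'loss' tensor (as AllenNLP models
    # return during training); gather them on the source device and average.
    losses = gather([output['loss'].unsqueeze(0) for output in outputs],
                    used_device_ids[0], dim=0)
    return losses.mean()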

pytest.ini (+3 lines)

@@ -28,5 +28,8 @@ filterwarnings =
     ignore:@asynchronous is deprecated, use coroutines instead::tornado\.web
     #
     ignore:encoding is deprecated, Use raw=False instead.:PendingDeprecationWarning:msgpack_numpy
+    # For `spacy==2.0.11`
+    ignore:Direct calling implementation's unpack.*:PendingDeprecationWarning:msgpack_numpy
+    ignore:The binary mode of fromstring is deprecated.*:DeprecationWarning:msgpack_numpy
     # 4. ignore these `ExperimentalFeatureWarning` for now, but record them once
     once:This particular transformer implementation is a provisional feature.*::allennlp\.modules\.seq2seq_encoders\.bidirectional_language_model_transformer
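Each filterwarnings entry uses the action:message:category:module form that pytest feeds into Python's warning filters, so the new lines simply silence two deprecation warnings that msgpack_numpy emits under spacy 2.0.11. As a rough Python equivalent of one of the added entries (illustration only, not part of the commit):

import warnings

# Roughly what the added pytest.ini line does: ignore msgpack_numpy's
# "binary mode of fromstring is deprecated" DeprecationWarning.
warnings.filterwarnings(
    "ignore",
    message="The binary mode of fromstring is deprecated.*",
    category=DeprecationWarning,
    module="msgpack_numpy",
)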
