
Make dist_reduce work for tensors #5147

Merged · 3 commits · Apr 23, 2021
7 changes: 5 additions & 2 deletions allennlp/nn/util.py
@@ -2016,7 +2016,7 @@ def tiny_value_of_dtype(dtype: torch.dtype):
         raise TypeError("Does not support dtype " + str(dtype))
 
 
-_V = TypeVar("_V", int, float)
+_V = TypeVar("_V", int, float, torch.Tensor)
 
 
 def dist_reduce(value: _V, reduce_op, **kwargs) -> _V:
@@ -2046,6 +2046,9 @@ def dist_reduce(value: _V, reduce_op, **kwargs) -> _V:
     device = int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device())
     value_tensor = torch.tensor(value, device=device, **kwargs)
     dist.all_reduce(value_tensor, op=reduce_op)
+
+    if isinstance(value, torch.Tensor):
+        return value_tensor
     return value_tensor.item()  # type: ignore[return-value]


@@ -2062,4 +2065,4 @@ def dist_reduce_sum(value: _V, **kwargs) -> _V:
     # result in an `AttributeError`.
     if not is_distributed():
         return value
-    return dist_reduce(value, dist.ReduceOp.SUM)
+    return dist_reduce(value, dist.ReduceOp.SUM, **kwargs)
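
For reference, a minimal usage sketch (not part of the diff) of what this change enables: in a single-process run dist_reduce_sum simply returns its input, and with this patch a tensor input now comes back as a tensor instead of being collapsed with .item().

import torch

from allennlp.nn import util

# Single-process run: dist_reduce_sum is a no-op and returns the input unchanged.
count = util.dist_reduce_sum(23)
assert count == 23

# After this change, a tensor input is returned as a tensor; in a distributed run
# it would first be summed element-wise across workers via dist.all_reduce.
totals = util.dist_reduce_sum(torch.tensor([1.0, 2.0, 3.0]))
assert (totals == torch.tensor([1.0, 2.0, 3.0])).all()
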
44 changes: 43 additions & 1 deletion tests/nn/util_test.py
@@ -1,6 +1,6 @@
 import json
 import random
-from typing import NamedTuple, Any
+from typing import NamedTuple, Any, Union, Callable, Dict, List
 
 import numpy
 from numpy.testing import assert_array_almost_equal, assert_almost_equal
@@ -1719,3 +1719,45 @@ def test_get_token_ids_from_text_field_tensors(self):
         tensors = text_field.as_tensor(text_field.get_padding_lengths())
         token_ids = util.get_token_ids_from_text_field_tensors(tensors)
         assert (token_ids == expected_token_ids).all()
+
+    def test_dist_reduce_sum(self):
+
+        value = 23
+        ret_value = util.dist_reduce_sum(value)
+        assert ret_value == 23
+
+        value = torch.Tensor([1, 2, 3])
+        ret_value = util.dist_reduce_sum(value)
+        assert (ret_value == value).all().item()
+
+        from allennlp.common.testing.distributed_test import run_distributed_test
+
+        func_kwargs = {"value": [torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])]}
+        desired_values = torch.Tensor([5, 7, 9])
+
+        run_distributed_test(
+            [-1, -1],
+            global_distributed_func,
+            function=util.dist_reduce_sum,
+            func_kwargs=func_kwargs,
+            desired_values=desired_values,
+        )
+
+
+def global_distributed_func(
+    global_rank: int,
+    world_size: int,
+    gpu_id: Union[int, torch.device],
+    function: Callable,
+    func_kwargs: Dict[str, List[Any]],
+    desired_values: torch.Tensor,
+):
+    kwargs = {}
+
+    # Use the arguments meant for the process with rank `global_rank`.
+    for argname in func_kwargs:
+        kwargs[argname] = func_kwargs[argname][global_rank]
+
+    output = function(**kwargs)
+
+    assert (output == desired_values).all().item()
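
A quick sanity check on the expected values in the distributed test above (a sketch, not part of the PR): the two workers, which the -1 device ids appear to pin to CPU, contribute [1, 2, 3] and [4, 5, 6], so an element-wise SUM reduction should produce [5, 7, 9].

import torch

# Element-wise sum of the per-rank inputs used in the test above.
rank_values = [torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])]
reduced = rank_values[0] + rank_values[1]
assert (reduced == torch.Tensor([5, 7, 9])).all().item()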