
Commit 6c760c9

take care of the all gather for contrastive loss
1 parent 9a364dd commit 6c760c9

File tree

3 files changed: +82 -2 lines

gigagan_pytorch/distributed.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import torch
import torch.nn.functional as F
from torch.autograd import Function
import torch.distributed as dist

from einops import rearrange

# helpers

def exists(val):
    return val is not None

def pad_dim_to(t, length, dim = 0):
    # zero-pad tensor `t` along `dim` up to `length`
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# distributed helpers

def all_gather_variable_dim(t, dim = 0, sizes = None):
    # all-gather tensors whose size along `dim` may differ across ranks:
    # pad to the max size, gather, then trim away the padding
    device, world_size = t.device, dist.get_world_size()

    if not exists(sizes):
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for _ in range(world_size)]
        dist.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for _ in range(world_size)]
    dist.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    # mask selects only the positions each rank actually contributed
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

class AllGather(Function):
    # autograd-aware all-gather: forward gathers across ranks, backward
    # routes each rank the gradient slice for its own contribution

    @staticmethod
    def forward(ctx, x, dim, sizes):
        is_dist = dist.is_initialized() and dist.get_world_size() > 1
        ctx.is_dist = is_dist

        if not is_dist:
            return x, None

        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        if not ctx.is_dist:
            return grads, None, None

        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

all_gather = AllGather.apply
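
Since AllGather falls back to a passthrough when no process group is initialized, the helpers can be exercised on a single process. A minimal sketch (the shapes are illustrative, not from this commit):

import torch
from gigagan_pytorch.distributed import all_gather, pad_dim_to

x = torch.randn(3, 8)

# pad the batch dimension from 3 rows up to 5 with zeros
padded = pad_dim_to(x, 5, dim = 0)
assert padded.shape == (5, 8)

# with no initialized process group, all_gather returns its input
# unchanged and no sizes tensor
out, sizes = all_gather(x, 0, None)
assert out.shape == (3, 8) and sizes is None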

gigagan_pytorch/gigagan_pytorch.py

Lines changed: 11 additions & 1 deletion
@@ -26,6 +26,7 @@
 from gigagan_pytorch.version import __version__
 from gigagan_pytorch.open_clip import OpenClipAdapter
 from gigagan_pytorch.optimizer import get_optimizer
+from gigagan_pytorch.distributed import all_gather

 from tqdm import tqdm
@@ -175,8 +176,11 @@ def aux_clip_loss(
 ):
     assert exists(texts) ^ exists(text_embeds)

+    images, batch_sizes = all_gather(images, 0, None)
+
     if exists(texts):
         text_embeds, _ = clip.embed_texts(texts)
+        text_embeds, _ = all_gather(text_embeds, 0, batch_sizes)

     return clip.contrastive_loss(images = images, text_embeds = text_embeds)
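
Gathering before the loss matters because contrastive objectives treat every other pair in the batch as a negative, so computing the loss over the global batch across all ranks gives a larger pool of negatives than each rank's local batch. clip.contrastive_loss itself is not part of this diff; the following is only a generic InfoNCE-style sketch of such a loss over already-embedded, gathered features (the temperature value is illustrative):

import torch
import torch.nn.functional as F

def contrastive_loss_sketch(image_embeds, text_embeds, temperature = 0.07):
    # cosine similarity between every gathered image/text pair
    image_embeds = F.normalize(image_embeds, dim = -1)
    text_embeds = F.normalize(text_embeds, dim = -1)
    logits = image_embeds @ text_embeds.t() / temperature

    # matched pairs lie on the diagonal; all off-diagonal entries act as
    # negatives, which is what a larger gathered batch improves
    labels = torch.arange(logits.shape[0], device = logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2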
@@ -1572,6 +1576,9 @@ def init_(self, m):
     def resize_image_to(self, images, resolution):
         return F.interpolate(images, resolution, mode = self.resize_mode)

+    def real_images_to_rgbs(self, images):
+        return [self.resize_image_to(images, resolution) for resolution in self.multiscale_input_resolutions]
+
     @property
     def total_params(self):
         return sum([p.numel() for p in self.parameters()])
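
real_images_to_rgbs simply resizes the real images once per multiscale input resolution of the discriminator, so the trainer no longer repeats the list comprehension inline. A standalone sketch of the same idea (the resolutions and the 'bilinear' mode are assumptions; the real values come from multiscale_input_resolutions and self.resize_mode):

import torch
import torch.nn.functional as F

# hypothetical resolutions; the actual list lives on the Discriminator
multiscale_input_resolutions = (64, 32, 16)

images = torch.randn(2, 3, 128, 128)
rgbs = [F.interpolate(images, res, mode = 'bilinear') for res in multiscale_input_resolutions]

assert [t.shape[-1] for t in rgbs] == [64, 32, 16]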
@@ -2160,7 +2167,7 @@ def train_discriminator_step(
         real_images = real_images.to(self.device)
         real_images.requires_grad_()

-        real_images_rgbs = [self.resize_image_to(real_images, resolution) for resolution in self.unwrapped_D.multiscale_input_resolutions]
+        real_images_rgbs = self.unwrapped_D.real_images_to_rgbs(real_images)

         # diff augment real images
@@ -2331,8 +2338,11 @@ def train_discriminator_step(
             calc_aux_loss = False
         )

+        real_images_rgbs = self.D.real_images_to_rgbs(real_images)
+
         real_logits, *_ = self.D(
             real_images,
+            real_images_rgbs,
             texts = texts,
             return_multiscale_outputs = False,
             calc_aux_loss = False

gigagan_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.2.16'
+__version__ = '0.2.17'
