
Commit 7ba5151

remove xformers dependency
1 parent aa03afc commit 7ba5151

File tree

1 file changed: +44 −24 lines

lgm/mvdream/mv_unet.py

Lines changed: 44 additions & 24 deletions
@@ -11,19 +11,39 @@
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.models.modeling_utils import ModelMixin
 
-# require xformers!
-import xformers
-import xformers.ops
-
 from kiui.cam import orbit_camera
 
+def memory_efficient_attention(q, k, v, attn_bias=None, op=None):
+    """
+    Plain-PyTorch replacement for xformers.ops.memory_efficient_attention.
+
+    Parameters:
+    - q, k, v: query, key, value tensors of shape (batch_size, seq_len, dim)
+    - attn_bias: optional bias broadcastable to (batch_size, seq_len_q, seq_len_k)
+    - op: ignored; accepted only so call sites that still pass op=self.attention_op keep working
+    Returns:
+    - output tensor of shape (batch_size, seq_len_q, dim)
+    """
+    # Calculate scaled dot-product attention scores
+    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+    if attn_bias is not None:
+        scores = scores + attn_bias
+
+    # Apply softmax to get the attention weights
+    attention_weights = torch.nn.functional.softmax(scores, dim=-1)
+
+    # Compute the weighted sum of values
+    output = torch.matmul(attention_weights, v)
+
+    return output
+
 def get_camera(
     num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
 ):
     angle_gap = azimuth_span / num_frames
     cameras = []
     for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
-
+
         pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4]
 
         # opengl to blender
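The helper added above computes standard attention and materializes the full (seq_len_q, seq_len_k) score matrix, so it removes the xformers dependency but not xformers' memory savings. A possible drop-in alternative, not part of this commit (the name memory_efficient_attention_sdpa is illustrative), is to dispatch to PyTorch's fused kernel when PyTorch >= 2.0 is available and fall back to the same naive math otherwise:

import math
import torch
import torch.nn.functional as F

def memory_efficient_attention_sdpa(q, k, v, attn_bias=None, op=None):
    # `op` mirrors the xformers call signature used below; it is ignored here.
    if hasattr(F, "scaled_dot_product_attention"):
        # Fused kernel (PyTorch >= 2.0): avoids materializing the full score matrix.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
    # Fallback: the naive attention this commit introduces.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    if attn_bias is not None:
        scores = scores + attn_bias
    return torch.matmul(F.softmax(scores, dim=-1), v)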
@@ -140,17 +160,17 @@ def forward(self, x):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     def __init__(
-        self,
-        query_dim,
-        context_dim=None,
-        heads=8,
-        dim_head=64,
+        self,
+        query_dim,
+        context_dim=None,
+        heads=8,
+        dim_head=64,
         dropout=0.0,
         ip_dim=0,
         ip_weight=1,
     ):
         super().__init__()
-
+
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
 
@@ -199,7 +219,7 @@ def forward(self, x, context=None):
         )
 
         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(
+        out = memory_efficient_attention(
             q, k, v, attn_bias=None, op=self.attention_op
         )
 
@@ -213,7 +233,7 @@ def forward(self, x, context=None):
                 (k_ip, v_ip),
             )
             # actually compute the attention, what we cannot get enough of
-            out_ip = xformers.ops.memory_efficient_attention(
+            out_ip = memory_efficient_attention(
                 q, k_ip, v_ip, attn_bias=None, op=self.attention_op
             )
             out = out + self.ip_weight * out_ip
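At these call sites the pure-PyTorch replacement is expected to behave like the xformers kernel. A minimal sanity check, assuming PyTorch >= 2.0 (the shapes are illustrative, and the naive_attention helper below simply restates the new function's math so the snippet is self-contained):

import math
import torch
import torch.nn.functional as F

def naive_attention(q, k, v, attn_bias=None):
    # Same math as the memory_efficient_attention helper added in this commit.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    if attn_bias is not None:
        scores = scores + attn_bias
    return torch.matmul(F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
q = torch.randn(2, 16, 64)  # (batch, seq_len_q, dim)
k = torch.randn(2, 24, 64)  # (batch, seq_len_k, dim)
v = torch.randn(2, 24, 64)  # (batch, seq_len_k, dim)

ref = F.scaled_dot_product_attention(q, k, v)  # fused reference implementation
print(torch.allclose(naive_attention(q, k, v), ref, atol=1e-5))  # expected: True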
@@ -228,7 +248,7 @@ def forward(self, x, context=None):
 
 
 class BasicTransformerBlock3D(nn.Module):
-
+
     def __init__(
         self,
         dim,
@@ -259,7 +279,7 @@ def __init__(
             # ip only applies to cross-attention
             ip_dim=ip_dim,
             ip_weight=ip_weight,
-        )
+        )
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
@@ -311,9 +331,9 @@ def __init__(
                 for d in range(depth)
             ]
         )
-
+
         self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
-
+
 
     def forward(self, x, context=None, num_frames=1):
         # note: if no context is given, cross-attention defaults to self-attention
@@ -328,7 +348,7 @@ def forward(self, x, context=None, num_frames=1):
             x = block(x, context=context[i], num_frames=num_frames)
         x = self.proj_out(x)
         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
-
+
         return x + x_in
 
 
@@ -672,7 +692,7 @@ def __init__(
     ):
         super().__init__()
         assert context_dim is not None
-
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
@@ -699,7 +719,7 @@ def __init__(
                 "as a list/tuple (per-level) with the same length as channel_mult"
             )
         self.num_res_blocks = num_res_blocks
-
+
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
             assert all(
@@ -848,7 +868,7 @@ def __init__(
                 else:
                     num_heads = ch // num_head_channels
                     dim_head = num_head_channels
-
+
         self.middle_block = CondSequential(
             ResBlock(
                 ch,
@@ -865,7 +885,7 @@ def __init__(
                 depth=transformer_depth,
                 ip_dim=self.ip_dim,
                 ip_weight=self.ip_weight,
-            ),
+            ),
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -983,7 +1003,7 @@ def forward(
         # Add camera embeddings
         if camera is not None:
             emb = emb + self.camera_embed(camera)
-
+
         # imagedream variant
         if self.ip_dim > 0:
             x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
@@ -1002,4 +1022,4 @@ def forward(
         if self.predict_codebook_ids:
             return self.id_predictor(h)
         else:
-            return self.out(h)
+            return self.out(h)
