 from diffusers.configuration_utils import ConfigMixin
 from diffusers.models.modeling_utils import ModelMixin

-# require xformers!
-import xformers
-import xformers.ops
-
 from kiui.cam import orbit_camera

+def memory_efficient_attention(q, k, v, attn_bias=None, op=None):  # op is accepted (and ignored) so existing xformers-style call sites keep working
+    """
+    Plain-PyTorch drop-in for xformers.ops.memory_efficient_attention.
+
+    Parameters:
+    - q: query tensor of shape (batch_size, seq_len_q, dim); k, v: key/value tensors of shape (batch_size, seq_len_k, dim)
+    - attn_bias: optional additive bias, broadcastable to (batch_size, seq_len_q, seq_len_k)
+
+    Returns:
+    - output tensor of shape (batch_size, seq_len_q, dim)
+    """
+    # Calculate attention scores
+    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+    if attn_bias is not None:
+        scores += attn_bias
+
+    # Apply softmax to get attention weights
+    attention_weights = torch.nn.functional.softmax(scores, dim=-1)
+
+    # Compute the weighted sum of values
+    output = torch.matmul(attention_weights, v)
+
+    return output
+
 def get_camera(
     num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
 ):
     angle_gap = azimuth_span / num_frames
     cameras = []
     for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
-
+
         pose = orbit_camera(elevation, azimuth, radius=1)  # [4, 4]

         # opengl to blender
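
The helper above reproduces the math of the removed xformers kernel with plain PyTorch ops, but it materializes the full (seq_len_q, seq_len_k) score matrix, trading memory for one less dependency. A minimal sanity-check sketch (assuming PyTorch 2.x so that torch.nn.functional.scaled_dot_product_attention is available; the tensor shapes are illustrative, not taken from this file):

import math
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch*heads, seq_len, dim_head).
q = torch.randn(16, 32, 64)
k = torch.randn(16, 48, 64)
v = torch.randn(16, 48, 64)

# Naive formulation used by the new helper above.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
naive = torch.matmul(torch.softmax(scores, dim=-1), v)

# PyTorch's fused kernel computes the same quantity.
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(naive, fused, atol=1e-5))  # expected: True
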
@@ -140,17 +160,17 @@ def forward(self, x):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     def __init__(
-        self,
-        query_dim,
-        context_dim=None,
-        heads=8,
-        dim_head=64,
+        self,
+        query_dim,
+        context_dim=None,
+        heads=8,
+        dim_head=64,
         dropout=0.0,
         ip_dim=0,
         ip_weight=1,
     ):
         super().__init__()
-
+
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -199,7 +219,7 @@ def forward(self, x, context=None):
         )

         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(
+        out = memory_efficient_attention(
             q, k, v, attn_bias=None, op=self.attention_op
         )

@@ -213,7 +233,7 @@ def forward(self, x, context=None):
                 (k_ip, v_ip),
             )
             # actually compute the attention, what we cannot get enough of
-            out_ip = xformers.ops.memory_efficient_attention(
+            out_ip = memory_efficient_attention(
                 q, k_ip, v_ip, attn_bias=None, op=self.attention_op
             )
             out = out + self.ip_weight * out_ip
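
Since the hard xformers requirement is removed, a possible follow-up (a sketch only, not part of this commit) is to keep the fused xformers kernel when the package happens to be installed and fall back to the pure-PyTorch helper otherwise; xformers.ops.memory_efficient_attention accepts the same (q, k, v, attn_bias=..., op=...) arguments used at these call sites:

import math
import torch

# Prefer the fused xformers kernel when available, otherwise use the
# plain-PyTorch implementation introduced in this commit.
# (Sketch; the try/except import is an assumption, not part of the diff.)
try:
    from xformers.ops import memory_efficient_attention as _xformers_attention
except ImportError:
    _xformers_attention = None

def memory_efficient_attention(q, k, v, attn_bias=None, op=None):
    if _xformers_attention is not None:
        return _xformers_attention(q, k, v, attn_bias=attn_bias, op=op)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    if attn_bias is not None:
        scores += attn_bias
    return torch.matmul(torch.softmax(scores, dim=-1), v)
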
@@ -228,7 +248,7 @@ def forward(self, x, context=None):


 class BasicTransformerBlock3D(nn.Module):
-
+
     def __init__(
         self,
         dim,
@@ -259,7 +279,7 @@ def __init__(
             # ip only applies to cross-attention
             ip_dim=ip_dim,
             ip_weight=ip_weight,
-        )
+        )
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
@@ -311,9 +331,9 @@ def __init__(
                 for d in range(depth)
             ]
         )
-
+
         self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
-
+

     def forward(self, x, context=None, num_frames=1):
         # note: if no context is given, cross-attention defaults to self-attention
@@ -328,7 +348,7 @@ def forward(self, x, context=None, num_frames=1):
             x = block(x, context=context[i], num_frames=num_frames)
         x = self.proj_out(x)
         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
-
+
         return x + x_in


@@ -672,7 +692,7 @@ def __init__(
     ):
         super().__init__()
         assert context_dim is not None
-
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads

@@ -699,7 +719,7 @@ def __init__(
                     "as a list/tuple (per-level) with the same length as channel_mult"
                 )
             self.num_res_blocks = num_res_blocks
-
+
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
             assert all(
@@ -848,7 +868,7 @@ def __init__(
         else:
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
-
+
         self.middle_block = CondSequential(
             ResBlock(
                 ch,
@@ -865,7 +885,7 @@ def __init__(
                 depth=transformer_depth,
                 ip_dim=self.ip_dim,
                 ip_weight=self.ip_weight,
-            ),
+            ),
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -983,7 +1003,7 @@ def forward(
         # Add camera embeddings
         if camera is not None:
             emb = emb + self.camera_embed(camera)
-
+
         # imagedream variant
         if self.ip_dim > 0:
             x[(num_frames - 1) :: num_frames, :, :, :] = ip_img  # place at [4, 9]
@@ -1002,4 +1022,4 @@ def forward(
         if self.predict_codebook_ids:
             return self.id_predictor(h)
         else:
-            return self.out(h)
+            return self.out(h)