-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadp.py
2652 lines (2283 loc) · 103 KB
/
adp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
# License can be found in LICENSES/LICENSE_ADP.txt
import math
from inspect import isfunction
from math import ceil, floor, log, pi, log2
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from packaging import version
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum
from torch.backends.cuda import sdp_kernel
from torch.nn import functional as F
from dac.nn.layers import Snake1d
################################################ Utils ################################################
# 定义 ConditionedSequential 类,用于按顺序执行多个模块,并可选择性地使用映射张量
class ConditionedSequential(nn.Module):
"""
ConditionedSequential 类用于按顺序执行多个模块,并可选择性地使用映射张量。
初始化参数:
- *modules: 可变数量的模块,将按顺序执行。
"""
def __init__(self, *modules):
super().__init__()
# 将传入的模块列表存储为 ModuleList
self.module_list = nn.ModuleList(*modules)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
"""
前向传播方法,按顺序执行每个模块,并可选择性地使用映射张量。
参数:
- x (Tensor): 输入张量。
- mapping (Optional[Tensor], 可选): 可选的映射张量,用于每个模块。
返回:
- Tensor: 最后一个模块的输出。
"""
for module in self.module_list:
# 按顺序执行每个模块,并将输出传递给下一个模块
x = module(x, mapping)
return x
# 定义类型变量 T,用于泛型编程
T = TypeVar("T")
# 定义 default 函数,返回可选值或默认值
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
"""
返回可选值或默认值。
参数:
- val (Optional[T]): 需要检查的可选值。
- d (Union[Callable[..., T], T]): 默认值。如果 d 是可调用对象,则调用它以获取默认值。
返回:
- T: 如果 val 存在,则返回 val;否则返回 d 的值或调用 d() 的返回值。
"""
if exists(val):
return val
return d() if isfunction(d) else d
# 定义 exists 函数,检查一个值是否存在(不为 None)
def exists(val: Optional[T]) -> T:
"""
检查一个值是否存在(不为 None)。
参数:
- val (Optional[T]): 需要检查的值。
返回:
- T: 如果 val 不为 None,则返回 val。
"""
return val is not None
# 定义 closest_power_2 函数,找到最接近输入值的2的幂
def closest_power_2(x: float) -> int:
"""
找到最接近输入值的2的幂。
参数:
- x (float): 输入值。
返回:
- int: 最接近 x 的2的幂。
"""
# 计算 x 的以2为底的对数
exponent = log2(x)
# 定义距离函数,计算 x 与 2^z 的绝对差值
distance_fn = lambda z: abs(x - 2 ** z) # noqa
# 找到最接近的整数指数
exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
return 2 ** int(exponent_closest)
# 定义 group_dict_by_prefix 函数,根据键是否以指定前缀开头,将字典分成两部分
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
"""
根据键是否以指定前缀开头,将字典分成两部分。
参数:
- prefix (str): 前缀字符串,用于判断键是否以此开头。
- d (Dict): 输入字典。
返回:
- Tuple[Dict, Dict]: 返回一个元组,包含两个字典。
第一个字典包含不以 prefix 开头的键值对,
第二个字典包含以 prefix 开头的键值对。
"""
return_dicts: Tuple[Dict, Dict] = ({}, {})
for key in d.keys():
no_prefix = int(not key.startswith(prefix))
return_dicts[no_prefix][key] = d[key]
return return_dicts
# 定义 groupby 函数,根据键是否以指定前缀开头,将字典分成两部分,并可选择是否保留前缀
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
"""
根据键是否以指定前缀开头,将字典分成两部分,并可选择是否保留前缀。
参数:
- prefix (str): 前缀字符串,用于判断键是否以此开头。
- d (Dict): 输入字典。
- keep_prefix (bool, 可选): 是否保留前缀,默认为 False。
返回:
- Tuple[Dict, Dict]: 返回一个元组,包含两个字典。
如果 keep_prefix 为 False,第一个字典的键不包含 prefix;
如果 keep_prefix 为 True,第一个字典的键保留 prefix。
第二个字典始终包含以 prefix 开头的键值对。
"""
# 使用 group_dict_by_prefix 分割字典
kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
if keep_prefix:
return kwargs_with_prefix, kwargs
kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
return kwargs_no_prefix, kwargs
################################################ Convolutional Blocks ################################################
import typing as tp
# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
# License available in LICENSES/LICENSE_META.txt
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
padding_total: int = 0) -> int:
"""See `pad_for_conv1d`."""
"""
计算一维卷积所需的额外填充大小,以确保最后一个窗口是完整的。
参数:
- x (torch.Tensor): 输入张量,形状为 (batch_size, channels, length)。
- kernel_size (int): 卷积核大小。
- stride (int): 步幅。
- padding_total (int, 可选): 总填充大小,默认为 0。
返回:
- int: 所需的额外填充大小。
"""
# 获取输入张量的长度
length = x.shape[-1]
# 计算卷积后的帧数
n_frames = (length - kernel_size + padding_total) / stride + 1
# 计算理想长度,以确保最后一个窗口是完整的
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
# 返回所需的总填充大小
return ideal_length - length
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
"""Pad for a convolution to make sure that the last window is full.
Extra padding is added at the end. This is required to ensure that we can rebuild
an output of the same length, as otherwise, even with padding, some time steps
might get removed.
For instance, with total padding = 4, kernel size = 4, stride = 2:
0 0 1 2 3 4 5 0 0 # (0s are padding)
1 2 3 # (output frames of a convolution, last 0 is never used)
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
1 2 3 4 # once you removed padding, we are missing one time step !
"""
"""
为一维卷积填充输入张量,以确保最后一个窗口是完整的。
额外的填充将添加在末尾。这是为了确保我们可以重建相同长度的输出,
否则,即使有填充,某些时间步也可能被移除。
参数:
- x (torch.Tensor): 输入张量,形状为 (batch_size, channels, length)。
- kernel_size (int): 卷积核大小。
- stride (int): 步幅。
- padding_total (int, 可选): 总填充大小,默认为 0。
返回:
- torch.Tensor: 填充后的张量。
"""
# 计算额外填充大小
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
# 在末尾添加额外填充
return F.pad(x, (0, extra_padding))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen.
"""
"""
一维填充的简单封装,仅允许对小型输入进行反射填充。
如果输入长度小于最大填充大小,则在反射填充之前添加额外的零填充。
参数:
- x (torch.Tensor): 输入张量。
- paddings (Tuple[int, int]): 左右填充大小,例如 (padding_left, padding_right)。
- mode (str, 可选): 填充模式,默认为 'constant'。
- value (float, 可选): 填充值,默认为 0。
返回:
- torch.Tensor: 填充后的张量。
"""
# 获取输入张量的长度
length = x.shape[-1]
# 获取左右填充大小
padding_left, padding_right = paddings
# 确保填充大小为非负数
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
# 获取最大填充大小
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
# 计算额外的零填充大小
extra_pad = max_pad - length + 1
# 在末尾添加额外的零填充
x = F.pad(x, (0, extra_pad))
# 进行反射填充
padded = F.pad(x, paddings, mode, value)
# 计算截断位置
end = padded.shape[-1] - extra_pad
# 返回截断后的填充张量
return padded[..., :end]
else:
# 进行其他模式的填充
return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
"""
从一维张量中移除填充,处理零填充。
参数:
- x (torch.Tensor): 输入张量。
- paddings (Tuple[int, int]): 左右填充大小,例如 (padding_left, padding_right)。
返回:
- torch.Tensor: 移除填充后的张量。
"""
# 获取左右填充大小
padding_left, padding_right = paddings
# 确保填充大小为非负数
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
# 确保填充大小不超过张量长度
assert (padding_left + padding_right) <= x.shape[-1]
# 计算移除填充后的结束位置
end = x.shape[-1] - padding_right
# 返回移除填充后的张量
return x[..., padding_left: end]
class Conv1d(nn.Conv1d):
"""
Conv1d 类继承自 torch.nn.Conv1d,并重写了前向传播方法。
该类支持因果卷积和非对称填充,以确保卷积输出长度与输入长度一致。
初始化参数:
- *args: 传递给 torch.nn.Conv1d 的位置参数。
- **kwargs: 传递给 torch.nn.Conv1d 的关键字参数。
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: Tensor, causal=False) -> Tensor:
"""
前向传播方法,执行一维卷积操作,并进行必要的填充。
参数:
- x (Tensor): 输入张量,形状为 (batch_size, in_channels, length)。
- causal (bool, 可选): 是否进行因果卷积,默认为 False。
返回:
- Tensor: 卷积后的输出张量。
"""
# 获取卷积核大小
kernel_size = self.kernel_size[0]
# 获取步幅
stride = self.stride[0]
# 获取膨胀率
dilation = self.dilation[0]
# 计算有效卷积核大小,考虑膨胀率
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
# 计算总填充大小
padding_total = kernel_size - stride
# 计算额外填充大小,以确保最后一个窗口是完整的
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
if causal:
# Left padding for causal
# 如果是因果卷积,则在左侧添加填充
x = pad1d(x, (padding_total, extra_padding))
else:
# Asymmetric padding required for odd strides
# 否则,进行非对称填充(适用于奇数步幅)
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(x, (padding_left, padding_right + extra_padding))
return super().forward(x)
class ConvTranspose1d(nn.ConvTranspose1d):
"""
ConvTranspose1d 类继承自 torch.nn.ConvTranspose1d,并重写了前向传播方法。
该类支持因果卷积和非对称填充,以确保转置卷积输出长度与输入长度一致。
初始化参数:
- *args: 传递给 torch.nn.ConvTranspose1d 的位置参数。
- **kwargs: 传递给 torch.nn.ConvTranspose1d 的关键字参数。
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: Tensor, causal=False) -> Tensor:
"""
前向传播方法,执行一维转置卷积操作,并进行必要的裁剪。
参数:
- x (Tensor): 输入张量。
- causal (bool, 可选): 是否进行因果卷积,默认为 False。
返回:
- Tensor: 转置卷积后的输出张量。
"""
# 获取卷积核大小
kernel_size = self.kernel_size[0]
# 获取步幅
stride = self.stride[0]
# 计算总填充大小
padding_total = kernel_size - stride
# 调用父类的前向传播方法进行转置卷积操作
y = super().forward(x)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# 仅裁剪固定填充。`pad_for_conv1d` 中的额外填充将在最后移除,
# 以确保输出长度正确。
# 如果在这里移除,需要在编码器中传递相应的长度信息。
# in the encoder.
if causal:
# 计算右侧填充大小(向上取整)
padding_right = ceil(padding_total)
# 计算左侧填充大小
padding_left = padding_total - padding_right
# 移除填充
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
# 如果不是因果卷积,则进行非对称填充(适用于奇数步幅)
padding_right = padding_total // 2
padding_left = padding_total - padding_right
# 移除填充
y = unpad1d(y, (padding_left, padding_right))
# 返回裁剪后的输出张量
return y
def Downsample1d(
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
"""
创建一维下采样层。
参数:
- in_channels (int): 输入通道数。
- out_channels (int): 输出通道数。
- factor (int): 下采样因子。
- kernel_multiplier (int, 可选): 卷积核大小的乘数,默认为2。
返回:
- nn.Module: 一维下采样层。
断言:
- kernel_multiplier 必须为偶数,否则抛出错误。
"""
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
return Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * kernel_multiplier + 1,
stride=factor
)
def Upsample1d(
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
"""
创建一维上采样层。
参数:
- in_channels (int): 输入通道数。
- out_channels (int): 输出通道数。
- factor (int): 上采样因子。
- use_nearest (bool, 可选): 是否使用最近邻插值,默认为 False。
返回:
- nn.Module: 一维上采样层。
如果 factor 为 1,则使用 3x1 卷积层进行上采样。
如果 use_nearest 为 True,则使用最近邻插值和卷积层进行上采样。
否则,使用转置卷积层进行上采样。
"""
if factor == 1:
return Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3
)
if use_nearest:
return nn.Sequential(
# 最近邻插值上采样
nn.Upsample(scale_factor=factor, mode="nearest"),
Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3
),
)
else:
return ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * 2,
stride=factor
)
class ConvBlock1d(nn.Module):
"""
ConvBlock1d 类实现了一维卷积块,包含组归一化、激活函数和卷积层。
初始化参数:
- in_channels (int): 输入通道数。
- out_channels (int): 输出通道数。
- kernel_size (int, 可选): 卷积核大小,默认为3。
- stride (int, 可选): 步幅,默认为1。
- dilation (int, 可选): 膨胀率,默认为1。
- num_groups (int, 可选): 组归一化的组数,默认为8。
- use_norm (bool, 可选): 是否使用组归一化,默认为 True。
- use_snake (bool, 可选): 是否使用 Snake 激活函数,默认为 False。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
*,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
num_groups: int = 8,
use_norm: bool = True,
use_snake: bool = False
) -> None:
super().__init__()
# 定义组归一化层,如果 use_norm 为 False,则使用恒等映射
self.groupnorm = (
nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
if use_norm
else nn.Identity()
)
# 定义激活函数,如果 use_snake 为 True,则使用 Snake 激活函数;否则,使用 SiLU 激活函数
if use_snake:
self.activation = Snake1d(in_channels)
else:
self.activation = nn.SiLU()
# 定义卷积层
self.project = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)
def forward(
self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
) -> Tensor:
"""
前向传播方法,执行组归一化、激活函数和卷积操作。
参数:
- x (Tensor): 输入张量。
- scale_shift (Optional[Tuple[Tensor, Tensor]], 可选): 可选的缩放和偏移量。
- causal (bool, 可选): 是否进行因果卷积,默认为 False。
返回:
- Tensor: 卷积后的输出张量。
"""
# 应用组归一化
x = self.groupnorm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.activation(x)
return self.project(x, causal=causal)
class MappingToScaleShift(nn.Module):
"""
MappingToScaleShift 类用于将映射张量转换为缩放和偏移量。
初始化参数:
- features (int): 映射张量的特征维度。
- channels (int): 输出通道数,用于生成缩放和偏移量。
"""
def __init__(
self,
features: int,
channels: int,
):
super().__init__()
# 定义一个序列模块,用于将映射张量转换为缩放和偏移量
self.to_scale_shift = nn.Sequential(
nn.SiLU(),
# 全连接层,将特征维度映射到 2 倍的通道数
nn.Linear(in_features=features, out_features=channels * 2),
)
def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
"""
前向传播方法,将映射张量转换为缩放和偏移量。
参数:
- mapping (Tensor): 输入的映射张量,形状为 (batch_size, features)。
返回:
- Tuple[Tensor, Tensor]: 返回一个元组,包含缩放量和偏移量,形状均为 (batch_size, channels, 1)。
"""
scale_shift = self.to_scale_shift(mapping)
scale_shift = rearrange(scale_shift, "b c -> b c 1")
scale, shift = scale_shift.chunk(2, dim=1)
return scale, shift
class ResnetBlock1d(nn.Module):
"""
ResnetBlock1d 类实现了一维残差块,包含两个卷积块和一个残差连接。
初始化参数:
- in_channels (int): 输入通道数。
- out_channels (int): 输出通道数。
- kernel_size (int, 可选): 卷积核大小,默认为3。
- stride (int, 可选): 步幅,默认为1。
- dilation (int, 可选): 膨胀率,默认为1。
- use_norm (bool, 可选): 是否使用组归一化,默认为 True。
- use_snake (bool, 可选): 是否使用 Snake 激活函数,默认为 False。
- num_groups (int, 可选): 组归一化的组数,默认为8。
- context_mapping_features (Optional[int], 可选): 上下文映射特征维度。如果提供,则使用映射张量生成缩放和偏移量。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
*,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
use_norm: bool = True,
use_snake: bool = False,
num_groups: int = 8,
context_mapping_features: Optional[int] = None,
) -> None:
super().__init__()
self.use_mapping = exists(context_mapping_features)
# 定义第一个卷积块
self.block1 = ConvBlock1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
use_norm=use_norm,
num_groups=num_groups,
use_snake=use_snake
)
if self.use_mapping:
# 确保提供了上下文映射特征维度
assert exists(context_mapping_features)
# 定义映射到缩放和偏移量的模块
self.to_scale_shift = MappingToScaleShift(
# 上下文映射特征维度、输出通道数
features=context_mapping_features, channels=out_channels
)
# 定义第二个卷积块
self.block2 = ConvBlock1d(
in_channels=out_channels,
out_channels=out_channels,
use_norm=use_norm,
num_groups=num_groups,
use_snake=use_snake
)
# 定义输出映射层,如果输入通道数与输出通道数不同,则使用 1x1 卷积层;否则,使用恒等映射
self.to_out = (
Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
"""
前向传播方法,执行一维残差块操作。
参数:
- x (Tensor): 输入张量。
- mapping (Optional[Tensor], 可选): 可选的映射张量,用于生成缩放和偏移量。
- causal (bool, 可选): 是否进行因果卷积,默认为 False。
返回:
- Tensor: 残差块的输出张量。
"""
assert_message = "context mapping required if context_mapping_features > 0"
# 确保映射张量的存在性
assert not (self.use_mapping ^ exists(mapping)), assert_message
# 应用第一个卷积块
h = self.block1(x, causal=causal)
# 初始化缩放和偏移量
scale_shift = None
if self.use_mapping:
# 使用映射张量生成缩放和偏移量
scale_shift = self.to_scale_shift(mapping)
# 应用第二个卷积块
h = self.block2(h, scale_shift=scale_shift, causal=causal)
# 应用残差连接并返回结果
return h + self.to_out(x)
class Patcher(nn.Module):
"""
Patcher 类用于将一维输入张量分割成多个块(patch),并对每个块应用残差块。
初始化参数:
- in_channels (int): 输入通道数。
- out_channels (int): 输出通道数。
- patch_size (int): 每个块的大小。
- context_mapping_features (Optional[int], 可选): 上下文映射特征维度。如果提供,则使用映射张量生成缩放和偏移量。
- use_snake (bool, 可选): 是否使用 Snake 激活函数,默认为 False。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
context_mapping_features: Optional[int] = None,
use_snake: bool = False,
):
super().__init__()
assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
# 确保输出通道数可以被块大小整除
assert out_channels % patch_size == 0, assert_message
# 存储块大小
self.patch_size = patch_size
# 定义残差块,输出通道数除以块大小
self.block = ResnetBlock1d(
# 输入通道数
in_channels=in_channels,
# 输出通道数除以块大小
out_channels=out_channels // patch_size,
# 组归一化的组数
num_groups=1,
# 上下文映射特征维度
context_mapping_features=context_mapping_features,
# 是否使用 Snake 激活函数
use_snake=use_snake
)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
"""
前向传播方法,将输入张量分割成块并应用残差块。
参数:
- x (Tensor): 输入张量,形状为 (batch_size, in_channels, length)。
- mapping (Optional[Tensor], 可选): 可选的映射张量,用于生成缩放和偏移量。
- causal (bool, 可选): 是否进行因果卷积,默认为 False。
返回:
- Tensor: 输出张量,形状为 (batch_size, out_channels, length // patch_size)。
"""
# 应用残差块
x = self.block(x, mapping, causal=causal)
# 将张量重塑为 (batch_size, out_channels // patch_size, length, patch_size)
x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
# 返回重塑后的张量
return x
class Unpatcher(nn.Module):
"""
Unpatcher 类用于将一维输入张量中的块(patch)重新组合成原始张量,并对每个块应用残差块。
初始化参数:
- in_channels (int): 输入通道数。
- out_channels (int): 输出通道数。
- patch_size (int): 每个块的大小。
- context_mapping_features (Optional[int], 可选): 上下文映射特征维度。如果提供,则使用映射张量生成缩放和偏移量。
- use_snake (bool, 可选): 是否使用 Snake 激活函数,默认为 False。
"""
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
context_mapping_features: Optional[int] = None,
use_snake: bool = False
):
super().__init__()
assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
# 确保输入通道数可以被块大小整除
assert in_channels % patch_size == 0, assert_message
# 存储块大小
self.patch_size = patch_size
# 定义残差块,输入通道数除以块大小
self.block = ResnetBlock1d(
# 输入通道数除以块大小
in_channels=in_channels // patch_size,
# 输出通道数
out_channels=out_channels,
# 组归一化的组数
num_groups=1,
# 上下文映射特征维度
context_mapping_features=context_mapping_features,
# 是否使用 Snake 激活函数
use_snake=use_snake
)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
"""
前向传播方法,将输入张量中的块重新组合成原始张量,并对每个块应用残差块。
参数:
- x (Tensor): 输入张量,形状为 (batch_size, in_channels, length // patch_size)。
- mapping (Optional[Tensor], 可选): 可选的映射张量,用于生成缩放和偏移量。
- causal (bool, 可选): 是否进行因果卷积,默认为 False。
返回:
- Tensor: 输出张量,形状为 (batch_size, out_channels, length)。
"""
# 将张量重塑为 (batch_size, in_channels // patch_size, length, patch_size)
x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
# 应用残差块
x = self.block(x, mapping, causal=causal)
return x
################################################ Attention Components ################################################
# 定义前馈网络(FeedForward)层
def FeedForward(features: int, multiplier: int) -> nn.Module:
"""
创建前馈网络层,包含两个线性层和一个 GELU 激活函数。
参数:
- features (int): 输入特征的维度。
- multiplier (int): 中间层特征的维度乘数。
返回:
- nn.Module: 前馈网络层。
"""
mid_features = features * multiplier
return nn.Sequential(
nn.Linear(in_features=features, out_features=mid_features),
nn.GELU(),
nn.Linear(in_features=mid_features, out_features=features),
)
# 定义添加掩码的函数,用于掩码注意力分数
def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
"""
在注意力分数矩阵中添加掩码。
参数:
- sim (Tensor): 注意力分数矩阵,形状为 (batch_size, n, m)。
- mask (Tensor): 掩码矩阵,形状为 (batch_size, n, m) 或 (n, m)。
返回:
- Tensor: 添加掩码后的注意力分数矩阵。
"""
# 获取批次大小和掩码维度
b, ndim = sim.shape[0], mask.ndim
if ndim == 3:
# 如果掩码是 (batch_size, n, m),则重塑为 (batch_size, 1, n, m)
mask = rearrange(mask, "b n m -> b 1 n m")
if ndim == 2:
# 如果掩码是 (n, m),则重复批次维度
mask = repeat(mask, "n m -> b 1 n m", b=b)
# 获取数据类型允许的最小值
max_neg_value = -torch.finfo(sim.dtype).max
# 将掩码为 False 的位置设置为最小值
sim = sim.masked_fill(~mask, max_neg_value)
# 返回添加掩码后的注意力分数矩阵
return sim
# 定义因果掩码的函数,用于因果注意力
def causal_mask(q: Tensor, k: Tensor) -> Tensor:
"""
生成因果掩码,阻止模型关注未来的时间步。
参数:
- q (Tensor): 查询张量,形状为 (batch_size, num_heads, seq_len, head_features)。
- k (Tensor): 键张量,形状为 (batch_size, num_heads, seq_len, head_features)。
返回:
- Tensor: 因果掩码矩阵,形状为 (batch_size, seq_len, seq_len)。
"""
# 获取批次大小、序列长度和设备
b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
# 生成一个上三角矩阵,值为 False,表示不允许模型关注未来的时间步
mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
# 重复批次维度
mask = repeat(mask, "n m -> b n m", b=b)
return mask
class AttentionBase(nn.Module):
"""
AttentionBase 类实现了多头注意力机制的基础功能。
初始化参数:
- features (int): 输入特征的维度。
- head_features (int): 每个注意力头的特征维度。
- num_heads (int): 注意力头的数量。
- out_features (Optional[int], 可选): 输出特征的维度。如果未提供,则默认为输入特征的维度。
"""
def __init__(
self,
features: int,
*,
head_features: int,
num_heads: int,
out_features: Optional[int] = None,
):
super().__init__()
# 缩放因子,用于缩放注意力分数
self.scale = head_features**-0.5
# 注意力头的数量
self.num_heads = num_heads
# 中间特征的维度
mid_features = head_features * num_heads
# 如果未提供输出特征的维度,则默认为输入特征的维度
out_features = default(out_features, features)
# 定义输出线性层
self.to_out = nn.Linear(
in_features=mid_features, out_features=out_features
)
# 检查是否可以使用 Flash Attention
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
if not self.use_flash:
# 如果不使用 Flash Attention,则返回
return
# 获取 CUDA 设备属性
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
# Use flash attention for A100 GPUs
# 如果是 A100 GPU,则使用 Flash Attention
self.sdp_kernel_config = (True, False, False)
else:
# Don't use flash attention for other GPUs
# 否则,不使用 Flash Attention
self.sdp_kernel_config = (False, True, True)
def forward(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
) -> Tensor:
"""
前向传播方法,执行多头注意力机制。
参数:
- q (Tensor): 查询张量。
- k (Tensor): 键张量。
- v (Tensor): 值张量。
- mask (Optional[Tensor], 可选): 可选的掩码张量。
- is_causal (bool, 可选): 是否进行因果注意力,默认为 False。
返回:
- Tensor: 注意力机制的输出张量。
"""
# 分割多头
# Split heads
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
if not self.use_flash:
if is_causal and not mask:
# Mask out future tokens for causal attention
# 如果进行因果注意力且未提供掩码,则生成因果掩码
mask = causal_mask(q, k)
# Compute similarity matrix and add eventual mask
# 计算相似度矩阵,并添加掩码
sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
sim = add_mask(sim, mask) if exists(mask) else sim
# Get attention matrix with softmax
# 计算注意力权重
attn = sim.softmax(dim=-1, dtype=torch.float32)
# Compute values
# 计算输出
out = einsum("... n m, ... m d -> ... n d", attn, v)
else:
with sdp_kernel(*self.sdp_kernel_config):
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
# 重塑输出张量
out = rearrange(out, "b h n d -> b n (h d)")
# 应用输出线性层并返回结果
return self.to_out(out)
# 定义 Attention 类,实现多头注意力机制
class Attention(nn.Module):
"""
Attention 类实现了多头注意力机制,支持上下文特征和因果掩码。
初始化参数:
- features (int): 输入特征的维度。
- head_features (int): 每个注意力头的特征维度。
- num_heads (int): 注意力头的数量。
- out_features (Optional[int], 可选): 输出特征的维度。如果未提供,则默认为输入特征的维度。
- context_features (Optional[int], 可选): 上下文特征的维度。如果未提供,则默认为输入特征的维度。
- causal (bool, 可选): 是否进行因果注意力,默认为 False。
"""
def __init__(
self,
features: int,
*,
head_features: int,
num_heads: int,
out_features: Optional[int] = None,
context_features: Optional[int] = None,
causal: bool = False,
):
super().__init__()
# 存储上下文特征的维度
self.context_features = context_features
# 存储是否进行因果注意力
self.causal = causal
# 中间特征的维度
mid_features = head_features * num_heads
# 如果未提供上下文特征的维度,则默认为输入特征的维度
context_features = default(context_features, features)
# 定义输入特征的层归一化
self.norm = nn.LayerNorm(features)
# 定义上下文特征的层归一化
self.norm_context = nn.LayerNorm(context_features)
# 定义查询(q)线性层