@@ -20,15 +20,23 @@ class ContextFuseMethod:
20
20
FLAT = "flat"
21
21
PYRAMID = "pyramid"
22
22
RELATIVE = "relative"
23
- RANDOM = "random"
24
- GAUSS_SIGMA = "gauss-sigma"
25
- GAUSS_SIGMA_INV = "gauss-sigma inverse"
26
- DELAYED_REVERSE_SAWTOOTH = "delayed reverse sawtooth"
27
- PYRAMID_SIGMA = "pyramid-sigma"
28
- PYRAMID_SIGMA_INV = "pyramid-sigma inverse"
23
+ OVERLAP_LINEAR = "overlap-linear"
29
24
30
- LIST = [PYRAMID , FLAT , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
31
- LIST_STATIC = [PYRAMID , RELATIVE , FLAT , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
25
+ RANDOM = "🔬random"
26
+ RANDOM_DEPR = "random"
27
+ GAUSS_SIGMA = "🔬gauss-sigma"
28
+ GAUSS_SIGMA_DEPR = "gauss-sigma"
29
+ GAUSS_SIGMA_INV = "🔬gauss-sigma inverse"
30
+ GAUSS_SIGMA_INV_DEPR = "gauss-sigma inverse"
31
+ DELAYED_REVERSE_SAWTOOTH = "🔬delayed reverse sawtooth"
32
+ DELAYED_REVERSE_SAWTOOTH_DEPR = "delayed reverse sawtooth"
33
+ PYRAMID_SIGMA = "🔬pyramid-sigma"
34
+ PYRAMID_SIGMA_DEPR = "pyramid-sigma"
35
+ PYRAMID_SIGMA_INV = "🔬pyramid-sigma inverse"
36
+ PYRAMID_SIGMA_INV_DEPR = "pyramid-sigma inverse"
37
+
38
+ LIST = [PYRAMID , FLAT , OVERLAP_LINEAR , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
39
+ LIST_STATIC = [PYRAMID , RELATIVE , FLAT , OVERLAP_LINEAR , DELAYED_REVERSE_SAWTOOTH , PYRAMID_SIGMA , PYRAMID_SIGMA_INV , GAUSS_SIGMA , GAUSS_SIGMA_INV , RANDOM ]
32
40
33
41
34
42
class ContextType :
@@ -354,11 +362,11 @@ def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, Contex
354
362
}
355
363
356
364
357
- def get_context_weights (num_frames : int , fuse_method : str , sigma : Tensor = None ):
358
- weights_func = FUSE_MAPPING .get (fuse_method , None )
365
+ def get_context_weights (length : int , full_length : int , idxs : list [ int ], ctx_opts : ContextOptions , sigma : Tensor = None ):
366
+ weights_func = FUSE_MAPPING .get (ctx_opts . fuse_method , None )
359
367
if not weights_func :
360
- raise ValueError (f"Unknown fuse_method '{ fuse_method } '." )
361
- return weights_func (num_frames , sigma = sigma )
368
+ raise ValueError (f"Unknown fuse_method '{ ctx_opts . fuse_method } '." )
369
+ return weights_func (length , sigma = sigma , ctx_opts = ctx_opts , full_length = full_length , idxs = idxs )
362
370
363
371
364
372
def create_weights_flat (length : int , ** kwargs ) -> list [float ]:
@@ -376,6 +384,20 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
376
384
weight_sequence = list (range (1 , max_weight , 1 )) + [max_weight ] + list (range (max_weight - 1 , 0 , - 1 ))
377
385
return weight_sequence
378
386
387
+ def create_weights_overlap_linear (length : int , full_length : int , idxs : list [int ], ctx_opts : ContextOptions , ** kwargs ):
388
+ # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
389
+ # only expected overlap is given different weights
390
+ weights_torch = torch .ones ((length ))
391
+ # blend left-side on all except first window
392
+ if min (idxs ) > 0 :
393
+ ramp_up = torch .linspace (1e-37 , 1 , ctx_opts .context_overlap )
394
+ weights_torch [:ctx_opts .context_overlap ] = ramp_up
395
+ # blend right-side on all except last window
396
+ if max (idxs ) < full_length - 1 :
397
+ ramp_down = torch .linspace (1 , 1e-37 , ctx_opts .context_overlap )
398
+ weights_torch [- ctx_opts .context_overlap :] = ramp_down
399
+ return weights_torch
400
+
379
401
def create_weights_random (length : int , ** kwargs ) -> list [float ]:
380
402
if length % 2 == 0 :
381
403
max_weight = length // 2
@@ -454,12 +476,20 @@ def create_weights_delayed_reverse_sawtooth(length: int, **kwargs) -> list[float
454
476
ContextFuseMethod .FLAT : create_weights_flat ,
455
477
ContextFuseMethod .PYRAMID : create_weights_pyramid ,
456
478
ContextFuseMethod .RELATIVE : create_weights_pyramid ,
479
+ ContextFuseMethod .OVERLAP_LINEAR : create_weights_overlap_linear ,
480
+ # experimental
457
481
ContextFuseMethod .GAUSS_SIGMA : create_weights_gauss_sigma ,
482
+ ContextFuseMethod .GAUSS_SIGMA_DEPR : create_weights_gauss_sigma ,
458
483
ContextFuseMethod .GAUSS_SIGMA_INV : create_weights_gauss_sigma_inv ,
484
+ ContextFuseMethod .GAUSS_SIGMA_INV_DEPR : create_weights_gauss_sigma_inv ,
459
485
ContextFuseMethod .RANDOM : create_weights_random ,
486
+ ContextFuseMethod .RANDOM_DEPR : create_weights_random ,
460
487
ContextFuseMethod .DELAYED_REVERSE_SAWTOOTH : create_weights_delayed_reverse_sawtooth ,
488
+ ContextFuseMethod .DELAYED_REVERSE_SAWTOOTH_DEPR : create_weights_delayed_reverse_sawtooth ,
461
489
ContextFuseMethod .PYRAMID_SIGMA : create_weights_pyramid_sigma ,
490
+ ContextFuseMethod .PYRAMID_SIGMA_DEPR : create_weights_pyramid_sigma ,
462
491
ContextFuseMethod .PYRAMID_SIGMA_INV : create_weights_pyramid_sigma_inv ,
492
+ ContextFuseMethod .PYRAMID_SIGMA_INV_DEPR : create_weights_pyramid_sigma_inv ,
463
493
}
464
494
465
495
0 commit comments