Skip to content

Commit 9f2d24d

Browse files
authored
Merge PR #545 from Kosinkadink/develop
Added overlap-linear fuse method + cleanup
2 parents b99a0f4 + 6604151 commit 9f2d24d

File tree

4 files changed

+49
-22
lines changed

4 files changed

+49
-22
lines changed

animatediff/context.py

+42-12
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,23 @@ class ContextFuseMethod:
2020
FLAT = "flat"
2121
PYRAMID = "pyramid"
2222
RELATIVE = "relative"
23-
RANDOM = "random"
24-
GAUSS_SIGMA = "gauss-sigma"
25-
GAUSS_SIGMA_INV = "gauss-sigma inverse"
26-
DELAYED_REVERSE_SAWTOOTH = "delayed reverse sawtooth"
27-
PYRAMID_SIGMA = "pyramid-sigma"
28-
PYRAMID_SIGMA_INV = "pyramid-sigma inverse"
23+
OVERLAP_LINEAR = "overlap-linear"
2924

30-
LIST = [PYRAMID, FLAT, DELAYED_REVERSE_SAWTOOTH, PYRAMID_SIGMA, PYRAMID_SIGMA_INV, GAUSS_SIGMA, GAUSS_SIGMA_INV, RANDOM]
31-
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, DELAYED_REVERSE_SAWTOOTH, PYRAMID_SIGMA, PYRAMID_SIGMA_INV, GAUSS_SIGMA, GAUSS_SIGMA_INV, RANDOM]
25+
RANDOM = "🔬random"
26+
RANDOM_DEPR = "random"
27+
GAUSS_SIGMA = "🔬gauss-sigma"
28+
GAUSS_SIGMA_DEPR = "gauss-sigma"
29+
GAUSS_SIGMA_INV = "🔬gauss-sigma inverse"
30+
GAUSS_SIGMA_INV_DEPR = "gauss-sigma inverse"
31+
DELAYED_REVERSE_SAWTOOTH = "🔬delayed reverse sawtooth"
32+
DELAYED_REVERSE_SAWTOOTH_DEPR = "delayed reverse sawtooth"
33+
PYRAMID_SIGMA = "🔬pyramid-sigma"
34+
PYRAMID_SIGMA_DEPR = "pyramid-sigma"
35+
PYRAMID_SIGMA_INV = "🔬pyramid-sigma inverse"
36+
PYRAMID_SIGMA_INV_DEPR = "pyramid-sigma inverse"
37+
38+
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR, DELAYED_REVERSE_SAWTOOTH, PYRAMID_SIGMA, PYRAMID_SIGMA_INV, GAUSS_SIGMA, GAUSS_SIGMA_INV, RANDOM]
39+
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR, DELAYED_REVERSE_SAWTOOTH, PYRAMID_SIGMA, PYRAMID_SIGMA_INV, GAUSS_SIGMA, GAUSS_SIGMA_INV, RANDOM]
3240

3341

3442
class ContextType:
@@ -354,11 +362,11 @@ def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, Contex
354362
}
355363

356364

357-
def get_context_weights(num_frames: int, fuse_method: str, sigma: Tensor = None):
358-
weights_func = FUSE_MAPPING.get(fuse_method, None)
365+
def get_context_weights(length: int, full_length: int, idxs: list[int], ctx_opts: ContextOptions, sigma: Tensor=None):
366+
weights_func = FUSE_MAPPING.get(ctx_opts.fuse_method, None)
359367
if not weights_func:
360-
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
361-
return weights_func(num_frames, sigma=sigma )
368+
raise ValueError(f"Unknown fuse_method '{ctx_opts.fuse_method}'.")
369+
return weights_func(length, sigma=sigma, ctx_opts=ctx_opts, full_length=full_length, idxs=idxs)
362370

363371

364372
def create_weights_flat(length: int, **kwargs) -> list[float]:
@@ -376,6 +384,20 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
376384
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
377385
return weight_sequence
378386

387+
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], ctx_opts: "ContextOptions", **kwargs) -> torch.Tensor:
    """Return fuse weights that linearly blend only the overlapping edges of a context window.

    Every frame starts with weight 1. The first ``context_overlap`` frames ramp
    up from ~0 to 1 unless the window contains frame 0, and the last
    ``context_overlap`` frames ramp down from 1 to ~0 unless the window
    contains the final frame (``full_length - 1``).

    Args:
        length: Number of frames in this context window.
        full_length: Total number of frames across all windows.
        idxs: Frame indexes covered by this window.
        ctx_opts: Context options; only ``context_overlap`` is read here.
        **kwargs: Ignored; accepted so all weight functions share one call shape.

    Returns:
        A 1-D tensor of ``length`` weights.
    """
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones(length)
    # Clamp the ramp to the window size: an overlap larger than the window
    # would otherwise produce a ramp longer than the slice it is assigned to.
    overlap = max(0, min(ctx_opts.context_overlap, length))
    # Nothing to blend. Returning early also avoids the `[-0:]` slice below,
    # which would select the WHOLE tensor rather than an empty tail.
    if overlap == 0 or len(idxs) == 0:
        return weights_torch
    # The ramps stop at 1e-37 instead of 0 -- presumably so accumulated
    # weights never sum to exactly zero when used as a divisor (confirm
    # against the caller).
    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length - 1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
400+
379401
def create_weights_random(length: int, **kwargs) -> list[float]:
380402
if length % 2 == 0:
381403
max_weight = length // 2
@@ -454,12 +476,20 @@ def create_weights_delayed_reverse_sawtooth(length: int, **kwargs) -> list[float
454476
ContextFuseMethod.FLAT: create_weights_flat,
455477
ContextFuseMethod.PYRAMID: create_weights_pyramid,
456478
ContextFuseMethod.RELATIVE: create_weights_pyramid,
479+
ContextFuseMethod.OVERLAP_LINEAR: create_weights_overlap_linear,
480+
# experimental
457481
ContextFuseMethod.GAUSS_SIGMA: create_weights_gauss_sigma,
482+
ContextFuseMethod.GAUSS_SIGMA_DEPR: create_weights_gauss_sigma,
458483
ContextFuseMethod.GAUSS_SIGMA_INV: create_weights_gauss_sigma_inv,
484+
ContextFuseMethod.GAUSS_SIGMA_INV_DEPR: create_weights_gauss_sigma_inv,
459485
ContextFuseMethod.RANDOM: create_weights_random,
486+
ContextFuseMethod.RANDOM_DEPR: create_weights_random,
460487
ContextFuseMethod.DELAYED_REVERSE_SAWTOOTH: create_weights_delayed_reverse_sawtooth,
488+
ContextFuseMethod.DELAYED_REVERSE_SAWTOOTH_DEPR: create_weights_delayed_reverse_sawtooth,
461489
ContextFuseMethod.PYRAMID_SIGMA: create_weights_pyramid_sigma,
490+
ContextFuseMethod.PYRAMID_SIGMA_DEPR: create_weights_pyramid_sigma,
462491
ContextFuseMethod.PYRAMID_SIGMA_INV: create_weights_pyramid_sigma_inv,
492+
ContextFuseMethod.PYRAMID_SIGMA_INV_DEPR: create_weights_pyramid_sigma_inv,
463493
}
464494

465495

animatediff/motion_module_ad.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,7 @@ def forward(
13871387
count += 1
13881388
sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=len(sub_idxs))
13891389

1390-
weights = get_context_weights(len(sub_idxs), view_options.fuse_method) * batched_conds
1390+
weights = get_context_weights(len(sub_idxs), video_length, sub_idxs, view_options, sigma=transformer_options["sigmas"]) * batched_conds
13911391
weights_tensor = torch.Tensor(weights).to(device=hidden_states.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
13921392
value_final[:, sub_idxs] += sub_hidden_states * weights_tensor
13931393
count_final[:, sub_idxs] += weights_tensor

animatediff/sampling.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -783,9 +783,6 @@ def sliding_calc_cond_batch(executor: Callable, model, conds: list[list[dict]],
783783
multigpu_windows = {}
784784
start_idx = 0
785785
for device, work in ctxs_relative_work.items():
786-
# if device == x_in.device:
787-
# continue
788-
# multigpu_windows[device] = enumerated_context_windows
789786
if work == 0:
790787
continue
791788
end_idx = start_idx + work
@@ -817,14 +814,14 @@ def _handle_context_batch(device: torch.device, batch_windows, model_options_bat
817814

818815
for results in combined_results:
819816
for result in results:
820-
combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.ctx_idxs, result.window_idx, timestep,
817+
combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.ctx_idxs, result.window_idx, len(enumerated_context_windows), timestep,
821818
ADGS, NAIVE, CREF, conds_final, counts_final, biases_final)
822819

823820
else:
824821
for enum_window in enumerated_context_windows:
825822
results = evaluate_context_windows(executor, model, x_in, conds, timestep, [enum_window], model_options, CREF, ADGS)
826823
for result in results:
827-
combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.ctx_idxs, result.window_idx, timestep,
824+
combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.ctx_idxs, result.window_idx, len(enumerated_context_windows), timestep,
828825
ADGS, NAIVE, CREF, conds_final, counts_final, biases_final)
829826
finally:
830827
CREF.cleanup(model_options)
@@ -834,7 +831,7 @@ def _handle_context_batch(device: torch.device, batch_windows, model_options_bat
834831

835832
# finalize conds
836833
if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
837-
# already normalized, so return as is
834+
# relative is already normalized, so return as is
838835
del counts_final
839836
return conds_final
840837
else:
@@ -898,7 +895,7 @@ def evaluate_context_windows(executor, model: BaseModel, x_in: Tensor, conds, ti
898895
return results
899896

900897

901-
def combine_context_window_results(x_in: Tensor, sub_conds_out, sub_conds, ctx_idxs: list[int], window_idx: int, timestep,
898+
def combine_context_window_results(x_in: Tensor, sub_conds_out, sub_conds, ctx_idxs: list[int], window_idx: int, total_windows: int, timestep,
902899
ADGS: AnimateDiffGlobalState, NAIVE: NaiveReuseHandler, CREF: ContextRefHandler,
903900
conds_final: list[Tensor], counts_final: list[Tensor], biases_final: list[Tensor]):
904901
if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
@@ -915,7 +912,7 @@ def combine_context_window_results(x_in: Tensor, sub_conds_out, sub_conds, ctx_i
915912
biases_final[i][idx] = bias_total + bias
916913
else:
917914
# add conds and counts based on weights of fuse method
918-
weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep)
915+
weights = get_context_weights(len(ctx_idxs), x_in.shape[0], ctx_idxs, ADGS.params.context_options, sigma=timestep)
919916
weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
920917
for i in range(len(sub_conds_out)):
921918
conds_final[i][ctx_idxs] += sub_conds_out[i] * weights_tensor

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-animatediff-evolved"
33
description = "Improved AnimateDiff integration for ComfyUI."
4-
version = "1.5.1"
4+
version = "1.5.2"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)