Skip to content

Commit 72890c3

Browse files
experiment(backend): autocast dtype in CustomLinear
This resolves an issue where specifying `float32` precision causes FLUX Fill to error. I noticed that our other customized torch modules do some dtype casting themselves, so maybe this is a fine place to do this? Maybe this could break things... See #7836
1 parent 92f0c28 commit 72890c3

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import TypeVar
2+
3+
import torch
4+
5+
T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
6+
7+
8+
def cast_to_dtype(t: T, to_dtype: torch.dtype) -> T:
9+
"""Helper function to cast an optional tensor to a target dtype."""
10+
11+
if t is None:
12+
# If the tensor is None, return it as is.
13+
return t
14+
15+
if t.dtype != to_dtype:
16+
# The tensor is on the wrong device and we don't care about the dtype - or the dtype is already correct.
17+
return t.to(dtype=to_dtype)
18+
19+
return t

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44

55
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
6+
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype
67
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
78
CustomModuleMixin,
89
)
@@ -73,6 +74,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
7374
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
    """Linear forward that autocasts this layer's parameters to match `input`.

    The weight and bias are first moved via `cast_to_device` (presumably onto
    `input.device` - confirm against that helper's implementation), then cast
    to `input.dtype` with `cast_to_dtype`. `bias` may be None; `cast_to_dtype`
    passes None through unchanged.
    """
    weight = cast_to_device(self.weight, input.device)
    bias = cast_to_device(self.bias, input.device)
    # Cast params to the input's dtype so F.linear doesn't fail on a dtype
    # mismatch (per the commit message: fixes float32 precision erroring
    # with FLUX Fill - see issue #7836).
    weight = cast_to_dtype(weight, input.dtype)
    bias = cast_to_dtype(bias, input.dtype)
    return torch.nn.functional.linear(input, weight, bias)
7782

7883
def forward(self, input: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)