Skip to content

Commit 5c2f906

Browse files
authored
Merge pull request #1431 from bghira/bugfix/text-enc-device-mismatch
When encoding text embeds, force the text encoders to move to the accelerator device on round one.
2 parents 8178c21 + 964fa82 commit 5c2f906

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

helpers/data_backend/factory.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,9 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken
453453
)
454454

455455

456-
def move_text_encoders(args, text_encoders: list, target_device: str):
456+
def move_text_encoders(args, text_encoders: list, target_device: str, force_move: bool = False):
457457
"""Move text encoders to the target device."""
458-
if text_encoders is None or not args.offload_during_startup:
458+
if text_encoders is None or (not args.offload_during_startup and not force_move):
459459
return
460460
# we'll move text encoder only if their precision arg is no_change
461461
# otherwise, we assume the user has already moved them to the correct device due to quantisation.
@@ -634,7 +634,7 @@ def configure_multi_databackend(
634634

635635
# Generate a TextEmbeddingCache object
636636
logger.debug(f"rank {get_rank()} is creating TextEmbeddingCache")
637-
move_text_encoders(args, text_encoders, accelerator.device)
637+
move_text_encoders(args, text_encoders, accelerator.device, force_move=True)
638638
init_backend["text_embed_cache"] = TextEmbeddingCache(
639639
id=init_backend["id"],
640640
data_backend=init_backend["data_backend"],

0 commit comments

Comments
 (0)