
[WIP][tests] add precomputation tests #234


Merged
merged 25 commits into from Feb 24, 2025
Changes from 4 commits
4 changes: 3 additions & 1 deletion finetrainers/args.py
@@ -337,6 +337,7 @@ class Args:
validation_every_n_steps: Optional[int] = None
enable_model_cpu_offload: bool = False
validation_frame_rate: int = 25
do_not_run_validation: bool = False

# Miscellaneous arguments
tracker_name: str = "finetrainers"
@@ -483,7 +484,8 @@ def parse_arguments() -> Args:
def validate_args(args: Args):
_validated_model_args(args)
_validate_training_args(args)
_validate_validation_args(args)
if not args.do_not_run_validation:
_validate_validation_args(args)


def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
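The hunk above only adds the dataclass field and skips `_validate_validation_args` when it is set; how the flag reaches the argument parser is not shown in this diff. A minimal, hedged sketch of one way it could be wired up (the function name, argument group, and help text are assumptions, not part of the PR):

# Hypothetical sketch (not shown in this hunk): exposing the new flag on the CLI.
import argparse

def _add_validation_skip_argument(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--do_not_run_validation",
        action="store_true",
        default=False,
        help="Skip validation arg checks and validation runs (useful for precomputation-only test runs).",
    )
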
7 changes: 5 additions & 2 deletions finetrainers/models/cogvideox/lora.py
@@ -3,7 +3,7 @@
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer

from .utils import prepare_rotary_positional_embeddings

@@ -15,7 +15,10 @@ def load_condition_models(
cache_dir: Optional[str] = None,
**kwargs,
):
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
try:
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
Member Author

Not super proud of this, but we cannot use T5Tokenizer with the dummy T5 tokenizer checkpoint; it fails with a sentencepiece error.

text_encoder = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
)
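For context, the merged fallback uses a bare `except:`; a stricter variant (an illustration only, not what the PR merges) would catch a concrete exception and log the switch:

# Illustrative variant only: narrow the bare `except:` and record the fallback.
import logging
from transformers import AutoTokenizer, T5Tokenizer

logger = logging.getLogger(__name__)

def load_tokenizer(model_id: str, revision: str = None, cache_dir: str = None):
    try:
        return T5Tokenizer.from_pretrained(
            model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir
        )
    except Exception as exc:  # e.g. the sentencepiece error seen with the dummy checkpoint
        logger.warning("T5Tokenizer failed (%s); falling back to AutoTokenizer.", exc)
        return AutoTokenizer.from_pretrained(
            model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir
        )
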
12 changes: 8 additions & 4 deletions finetrainers/trainer.py
@@ -255,7 +255,8 @@ def collate_fn(batch):

memory_statistics = get_memory_statistics()
logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)

# Precompute latents
latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
@@ -302,7 +303,8 @@ def collate_fn(batch):

memory_statistics = get_memory_statistics()
logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)

# Update dataloader to use precomputed conditions and latents
self.dataloader = torch.utils.data.DataLoader(
@@ -984,7 +986,8 @@ def validate(self, step: int, final_validation: bool = False) -> None:
free_memory()
memory_statistics = get_memory_statistics()
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)

if not final_validation:
self.transformer.train()
@@ -1107,7 +1110,8 @@ def _delete_components(self) -> None:
self.vae = None
self.scheduler = None
free_memory()
torch.cuda.synchronize(self.state.accelerator.device)
if torch.cuda.is_available():
torch.cuda.synchronize(self.state.accelerator.device)

def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
accelerator = self.state.accelerator
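The same `torch.cuda.is_available()` guard now wraps every `reset_peak_memory_stats` and `synchronize` call site in the trainer; a small hedged sketch of how those guards could be centralized in helpers (not part of the PR):

# Hypothetical helpers (not in the PR) that centralize the CUDA guard used above.
import torch

def reset_peak_memory_stats(device: torch.device) -> None:
    # reset_peak_memory_stats is a CUDA-only API, so skip it on CPU/MPS machines.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)

def synchronize_device(device: torch.device) -> None:
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
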
8 changes: 4 additions & 4 deletions finetrainers/utils/memory_utils.py
@@ -28,10 +28,10 @@ def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

return {
"memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
"memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
"max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
"max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
"memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision) if memory_allocated else None,
"memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision) if memory_reserved else None,
"max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision) if max_memory_allocated else None,
"max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision) if max_memory_reserved else None,
}


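Note that the `if value else None` guard also maps a legitimate zero-byte reading to None; an explicit None check would keep zeros. An alternative sketch, assuming `bytes_to_gigabytes` is the plain byte-to-GiB conversion already defined in this module:

# Alternative sketch, not the PR's code: only skip rounding when the stat is truly missing.
from typing import Optional

def bytes_to_gigabytes(num_bytes: int) -> float:  # assumed to match memory_utils.py
    return num_bytes / (1024 ** 3)

def maybe_round_gb(value: Optional[int], precision: int = 3) -> Optional[float]:
    return None if value is None else round(bytes_to_gigabytes(value), ndigits=precision)
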
Empty file added tests/trainers/__init__.py
Empty file.
Empty file.
31 changes: 31 additions & 0 deletions tests/trainers/cogvideox/test_cogvideox.py
@@ -0,0 +1,31 @@
import sys
from pathlib import Path

current_file = Path(__file__).resolve()
root_dir = current_file.parents[3]
sys.path.append(str(root_dir))

from ..test_trainers_common import TrainerTestMixin
from typing import Tuple
from finetrainers import Args
import unittest

# Copied for now.
def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]:
return tuple(map(int, resolution_bucket.split("x")))



class CogVideoXTester(unittest.TestCase, TrainerTestMixin):
model_name = "cogvideox"

def get_training_args(self):
args = Args()
args.model_name = self.model_name
args.training_type = "lora"
args.pretrained_model_name_or_path = "finetrainers/dummy-cogvideox"
args.data_root = "" # will be set from the tester method.
args.video_resolution_buckets = [parse_resolution_bucket("9x16x16")]
args.precompute_conditions = True
args.do_not_run_validation = True
return args
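
For reference, the copied helper simply splits the "frames x height x width" string into an integer tuple:

# e.g. the bucket used above:
assert parse_resolution_bucket("9x16x16") == (9, 16, 16)
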
63 changes: 63 additions & 0 deletions tests/trainers/test_trainers_common.py
@@ -0,0 +1,63 @@
import sys
from pathlib import Path

current_file = Path(__file__).resolve()
root_dir = current_file.parents[1]
sys.path.append(str(root_dir))


from finetrainers import Trainer
from finetrainers.utils.file_utils import string_to_filename
from finetrainers.constants import PRECOMPUTED_DIR_NAME, PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME
from huggingface_hub import snapshot_download
import tempfile
import glob
import os

class TrainerTestMixin:
model_name = None

def get_training_args(self):
raise NotImplementedError

def download_dataset_txt_format(self, cache_dir):
path = snapshot_download(repo_id="finetrainers/dummy-disney-dataset", repo_type="dataset", cache_dir=cache_dir)
return path

def test_precomputation_txt_format(self):
# Here we assume the dataset is formatted like:
# https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset/tree/main
training_args = self.get_training_args()

with tempfile.TemporaryDirectory() as tmpdir:
# Prepare remaining args.
training_args.data_root = Path(self.download_dataset_txt_format(cache_dir=tmpdir))

training_args.video_column = "videos.txt"
training_args.caption_column = "prompt.txt"
with open(f"{training_args.data_root}/{training_args.video_column}", "r", encoding="utf-8") as file:
video_paths = [training_args.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]

# Initialize trainer.
training_args.output_dir = tmpdir
trainer = Trainer(training_args)
training_args = trainer.args

# Perform precomputations.
trainer.prepare_dataset()
trainer.prepare_models()
trainer.prepare_precomputations()

cleaned_model_id = string_to_filename(training_args.pretrained_model_name_or_path)
precomputation_dir = (
Path(training_args.data_root) / f"{training_args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
)

# Checks.
conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
assert os.path.exists(precomputation_dir), f"Precomputation wasn't successful. Couldn't find the precomputed dir: {os.listdir(training_args.data_root)=}\n"
assert os.path.exists(conditions_dir), f"conditions dir ({str(conditions_dir)}) doesn't exist"
assert os.path.exists(latents_dir), f"latents dir ({str(latents_dir)}) doesn't exist"
assert len(video_paths) == len([p for p in conditions_dir.glob("*.pt")])
assert len(video_paths) == len([p for p in latents_dir.glob("*.pt")])
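
The mixin keeps model-specific testers small: each only has to provide `get_training_args`, and the precomputation test is inherited. A hedged sketch of how another model could reuse it (model and checkpoint names below are placeholders, not part of the PR):

# Hypothetical reuse of the mixin for another model; all names are placeholders.
import unittest

from finetrainers import Args
from tests.trainers.test_trainers_common import TrainerTestMixin


class SomeOtherModelTester(unittest.TestCase, TrainerTestMixin):
    model_name = "some-model"

    def get_training_args(self):
        args = Args()
        args.model_name = self.model_name
        args.training_type = "lora"
        args.pretrained_model_name_or_path = "org/dummy-checkpoint"  # placeholder
        args.data_root = ""  # set by the mixin's test method
        args.video_resolution_buckets = [(9, 16, 16)]
        args.precompute_conditions = True
        args.do_not_run_validation = True
        return args

Such a tester would then run the same test_precomputation_txt_format unchanged.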