Skip to content

Helpful messages #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions finetrainers/args.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import sys
from typing import Any, Dict, List, Optional, Tuple

import torch
Expand All @@ -11,6 +12,13 @@ class Args:
r"""
The arguments for the finetrainers training script.

For helpful information about arguments, run `python train.py --help`.

TODO(aryan): add `python train.py --recommend_configs --model_name <model_name>` to recommend
good training configs for a model after extensive testing.
TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
memory requirements per model, per training type with sensible training settings.

MODEL ARGUMENTS
---------------
model_name (`str`):
Expand Down Expand Up @@ -424,20 +432,31 @@ def to_dict(self) -> Dict[str, Any]:
}


# TODO(aryan): handle more informative messages
_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv


def parse_arguments() -> Args:
parser = argparse.ArgumentParser()

_add_model_arguments(parser)
_add_dataset_arguments(parser)
_add_dataloader_arguments(parser)
_add_diffusion_arguments(parser)
_add_training_arguments(parser)
_add_optimizer_arguments(parser)
_add_validation_arguments(parser)
_add_miscellaneous_arguments(parser)
if _IS_ARGUMENTS_REQUIRED:
_add_model_arguments(parser)
_add_dataset_arguments(parser)
_add_dataloader_arguments(parser)
_add_diffusion_arguments(parser)
_add_training_arguments(parser)
_add_optimizer_arguments(parser)
_add_validation_arguments(parser)
_add_miscellaneous_arguments(parser)

args = parser.parse_args()
return _map_to_args_type(args)
else:
_add_helper_arguments(parser)

args = parser.parse_args()
return _map_to_args_type(args)
args = parser.parse_args()
_display_helper_messages(args)
sys.exit(0)


def validate_args(args: Args):
Expand Down Expand Up @@ -932,6 +951,14 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
)


def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--list_models",
action="store_true",
help="List all the supported models.",
)


_DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
Expand Down Expand Up @@ -1089,3 +1116,10 @@ def _validate_validation_args(args: Args):
assert len(args.validation_prompts) == len(
args.validation_widths
), "Validation prompts and widths should be of same length"


def _display_helper_messages(args: argparse.Namespace):
if args.list_models:
print("Supported models:")
for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
print(f" {index + 1}. {model_name}")