Skip to content

Commit 21f222f

Browse files
Add out_dir option to eval (#244)
1 parent 33362db commit 21f222f

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

lerobot/scripts/eval.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def eval_policy(
209209
policy: torch.nn.Module,
210210
n_episodes: int,
211211
max_episodes_rendered: int = 0,
212-
video_dir: Path | None = None,
212+
videos_dir: Path | None = None,
213213
return_episode_data: bool = False,
214214
start_seed: int | None = None,
215215
enable_progbar: bool = False,
@@ -347,8 +347,8 @@ def render_frame(env: gym.vector.VectorEnv):
347347
):
348348
if n_episodes_rendered >= max_episodes_rendered:
349349
break
350-
video_dir.mkdir(parents=True, exist_ok=True)
351-
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
350+
videos_dir.mkdir(parents=True, exist_ok=True)
351+
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
352352
video_paths.append(str(video_path))
353353
thread = threading.Thread(
354354
target=write_video,
@@ -503,22 +503,19 @@ def _compile_episode_data(
503503
}
504504

505505

506-
def eval(
506+
def main(
507507
pretrained_policy_path: str | None = None,
508508
hydra_cfg_path: str | None = None,
509+
out_dir: str | None = None,
509510
config_overrides: list[str] | None = None,
510511
):
511512
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
512513
if hydra_cfg_path is None:
513514
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
514515
else:
515516
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
516-
out_dir = (
517-
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
518-
)
519-
520517
if out_dir is None:
521-
raise NotImplementedError()
518+
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
522519

523520
# Check device is available
524521
device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -546,7 +543,7 @@ def eval(
546543
policy,
547544
hydra_cfg.eval.n_episodes,
548545
max_episodes_rendered=10,
549-
video_dir=Path(out_dir) / "eval",
546+
videos_dir=Path(out_dir) / "videos",
550547
start_seed=hydra_cfg.seed,
551548
enable_progbar=True,
552549
enable_inner_progbar=True,
@@ -586,6 +583,13 @@ def eval(
586583
),
587584
)
588585
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
586+
parser.add_argument(
587+
"--out-dir",
588+
help=(
589+
"Where to save the evaluation outputs. If not provided, outputs are saved in "
590+
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
591+
),
592+
)
589593
parser.add_argument(
590594
"overrides",
591595
nargs="*",
@@ -594,7 +598,7 @@ def eval(
594598
args = parser.parse_args()
595599

596600
if args.pretrained_policy_name_or_path is None:
597-
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
601+
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
598602
else:
599603
try:
600604
pretrained_policy_path = Path(
@@ -618,4 +622,8 @@ def eval(
618622
"repo ID, nor is it an existing local directory."
619623
)
620624

621-
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
625+
main(
626+
pretrained_policy_path=pretrained_policy_path,
627+
out_dir=args.out_dir,
628+
config_overrides=args.overrides,
629+
)

lerobot/scripts/train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
327327

328328
# Note: this helper will be used in offline and online training loops.
329329
def evaluate_and_checkpoint_if_needed(step):
330+
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
331+
step_identifier = f"{step:0{_num_digits}d}"
332+
330333
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
331334
logging.info(f"Eval policy at step {step}")
332335
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
333336
eval_info = eval_policy(
334337
eval_env,
335338
policy,
336339
cfg.eval.n_episodes,
337-
video_dir=Path(out_dir) / "eval",
340+
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
338341
max_episodes_rendered=4,
339342
start_seed=cfg.seed,
340343
)
@@ -352,9 +355,7 @@ def evaluate_and_checkpoint_if_needed(step):
352355
policy,
353356
optimizer,
354357
lr_scheduler,
355-
identifier=str(step).zfill(
356-
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
357-
),
358+
identifier=step_identifier,
358359
)
359360
logging.info("Resume training")
360361

0 commit comments

Comments (0)