@@ -209,7 +209,7 @@ def eval_policy(
209
209
policy : torch .nn .Module ,
210
210
n_episodes : int ,
211
211
max_episodes_rendered : int = 0 ,
212
- video_dir : Path | None = None ,
212
+ videos_dir : Path | None = None ,
213
213
return_episode_data : bool = False ,
214
214
start_seed : int | None = None ,
215
215
enable_progbar : bool = False ,
@@ -347,8 +347,8 @@ def render_frame(env: gym.vector.VectorEnv):
347
347
):
348
348
if n_episodes_rendered >= max_episodes_rendered :
349
349
break
350
- video_dir .mkdir (parents = True , exist_ok = True )
351
- video_path = video_dir / f"eval_episode_{ n_episodes_rendered } .mp4"
350
+ videos_dir .mkdir (parents = True , exist_ok = True )
351
+ video_path = videos_dir / f"eval_episode_{ n_episodes_rendered } .mp4"
352
352
video_paths .append (str (video_path ))
353
353
thread = threading .Thread (
354
354
target = write_video ,
@@ -503,22 +503,19 @@ def _compile_episode_data(
503
503
}
504
504
505
505
506
- def eval (
506
+ def main (
507
507
pretrained_policy_path : str | None = None ,
508
508
hydra_cfg_path : str | None = None ,
509
+ out_dir : str | None = None ,
509
510
config_overrides : list [str ] | None = None ,
510
511
):
511
512
assert (pretrained_policy_path is None ) ^ (hydra_cfg_path is None )
512
513
if hydra_cfg_path is None :
513
514
hydra_cfg = init_hydra_config (pretrained_policy_path / "config.yaml" , config_overrides )
514
515
else :
515
516
hydra_cfg = init_hydra_config (hydra_cfg_path , config_overrides )
516
- out_dir = (
517
- f"outputs/eval/{ dt .now ().strftime ('%Y-%m-%d/%H-%M-%S' )} _{ hydra_cfg .env .name } _{ hydra_cfg .policy .name } "
518
- )
519
-
520
517
if out_dir is None :
521
- raise NotImplementedError ()
518
+ out_dir = f"outputs/eval/ { dt . now (). strftime ( '%Y-%m-%d/%H-%M-%S' ) } _ { hydra_cfg . env . name } _ { hydra_cfg . policy . name } "
522
519
523
520
# Check device is available
524
521
device = get_safe_torch_device (hydra_cfg .device , log = True )
@@ -546,7 +543,7 @@ def eval(
546
543
policy ,
547
544
hydra_cfg .eval .n_episodes ,
548
545
max_episodes_rendered = 10 ,
549
- video_dir = Path (out_dir ) / "eval " ,
546
+ videos_dir = Path (out_dir ) / "videos " ,
550
547
start_seed = hydra_cfg .seed ,
551
548
enable_progbar = True ,
552
549
enable_inner_progbar = True ,
@@ -586,6 +583,13 @@ def eval(
586
583
),
587
584
)
588
585
parser .add_argument ("--revision" , help = "Optionally provide the Hugging Face Hub revision ID." )
586
+ parser .add_argument (
587
+ "--out-dir" ,
588
+ help = (
589
+ "Where to save the evaluation outputs. If not provided, outputs are saved in "
590
+ "outputs/eval/{timestamp}_{env_name}_{policy_name}"
591
+ ),
592
+ )
589
593
parser .add_argument (
590
594
"overrides" ,
591
595
nargs = "*" ,
@@ -594,7 +598,7 @@ def eval(
594
598
args = parser .parse_args ()
595
599
596
600
if args .pretrained_policy_name_or_path is None :
597
- eval (hydra_cfg_path = args .config , config_overrides = args .overrides )
601
+ main (hydra_cfg_path = args .config , out_dir = args . out_dir , config_overrides = args .overrides )
598
602
else :
599
603
try :
600
604
pretrained_policy_path = Path (
@@ -618,4 +622,8 @@ def eval(
618
622
"repo ID, nor is it an existing local directory."
619
623
)
620
624
621
- eval (pretrained_policy_path = pretrained_policy_path , config_overrides = args .overrides )
625
+ main (
626
+ pretrained_policy_path = pretrained_policy_path ,
627
+ out_dir = args .out_dir ,
628
+ config_overrides = args .overrides ,
629
+ )
0 commit comments