Skip to content

Commit 33362db

Browse files
amandip7 and Cadene authored
Adding parameter dataloading_s to console logs and wandb for tracking… (#243)
Co-authored-by: Remi <[email protected]>
1 parent b0d954c commit 33362db

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

lerobot/scripts/train.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
150150
grad_norm = info["grad_norm"]
151151
lr = info["lr"]
152152
update_s = info["update_s"]
153+
dataloading_s = info["dataloading_s"]
153154

154155
# A sample is an (observation,action) pair, where observation and action
155156
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
@@ -170,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
170171
f"lr:{lr:0.1e}",
171172
# in seconds
172173
f"updt_s:{update_s:.3f}",
174+
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
173175
]
174176
logging.info(" ".join(log_items))
175177

@@ -382,7 +384,10 @@ def evaluate_and_checkpoint_if_needed(step):
382384
for _ in range(step, cfg.training.offline_steps):
383385
if step == 0:
384386
logging.info("Start offline training on a fixed dataset")
387+
388+
start_time = time.perf_counter()
385389
batch = next(dl_iter)
390+
dataloading_s = time.perf_counter() - start_time
386391

387392
for key in batch:
388393
batch[key] = batch[key].to(device, non_blocking=True)
@@ -397,6 +402,8 @@ def evaluate_and_checkpoint_if_needed(step):
397402
use_amp=cfg.use_amp,
398403
)
399404

405+
train_info["dataloading_s"] = dataloading_s
406+
400407
if step % cfg.training.log_freq == 0:
401408
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
402409

0 commit comments

Comments
 (0)