@@ -150,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     grad_norm = info["grad_norm"]
     lr = info["lr"]
     update_s = info["update_s"]
+    dataloading_s = info["dataloading_s"]
 
     # A sample is an (observation, action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
@@ -170,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
         f"lr:{lr:0.1e}",
         # in seconds
         f"updt_s:{update_s:.3f}",
+        f"data_s:{dataloading_s:.3f}",  # if not ~0, you are bottlenecked by cpu or io
     ]
     logging.info(" ".join(log_items))
 
@@ -382,7 +384,10 @@ def evaluate_and_checkpoint_if_needed(step):
     for _ in range(step, cfg.training.offline_steps):
         if step == 0:
             logging.info("Start offline training on a fixed dataset")
+
+        start_time = time.perf_counter()
         batch = next(dl_iter)
+        dataloading_s = time.perf_counter() - start_time
 
         for key in batch:
             batch[key] = batch[key].to(device, non_blocking=True)
@@ -397,6 +402,8 @@ def evaluate_and_checkpoint_if_needed(step):
             use_amp=cfg.use_amp,
         )
 
+        train_info["dataloading_s"] = dataloading_s
+
         if step % cfg.training.log_freq == 0:
            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
 
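For reference, the timing pattern this change adds can be exercised in isolation. The sketch below is illustrative only, assuming a generic PyTorch DataLoader over a toy TensorDataset; the dataset, batch size, and loop length are placeholders and are not part of this change.

    # Minimal sketch of the dataloading-time measurement used above.
    # The dataset/loader here are placeholders, not part of the patched script.
    import time

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 2))
    dl_iter = iter(DataLoader(dataset, batch_size=32, num_workers=0))

    for step in range(5):
        start_time = time.perf_counter()
        batch = next(dl_iter)  # time only the batch fetch, not the policy update
        dataloading_s = time.perf_counter() - start_time
        # If data_s stays well above ~0, the training step is waiting on CPU/IO-bound loading.
        print(f"step:{step} data_s:{dataloading_s:.3f}")

Wrapping only next(dl_iter) keeps the metric focused on the dataloader; the device transfer and optimizer update are already covered by update_s.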