Example for flax.jax_utils.prefetch_to_device with TFDS dataset #3869
Replies: 2 comments
-
Have you tried calling …?
-
It took me a while to understand how the batch dimensions were changed. I found it did not help with training performance, at least not in the single-GPU case.
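For the single-device case, the dimension change amounts to adding a leading device axis of length 1. A minimal sketch with numpy (the batch shape is assumed for illustration):

```python
import numpy as np

# A batch as yielded by as_numpy_iterator(): (batch_size, height, width)
batch = np.zeros((32, 28, 28), dtype=np.float32)

# prefetch_to_device expects a leading device axis: (num_devices, batch_size, ...)
sharded = np.expand_dims(batch, axis=0)
print(sharded.shape)  # (1, 32, 28, 28)
```

With a single GPU this only changes the layout the library sees, which is consistent with the observation that it does not by itself speed up training.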
-
Hi,

I want to use `flax.jax_utils.prefetch_to_device` to preload data to my GPU, but I could not figure out how to make it work with a TFDS dataset. The dataset is simple.

The above code works without using `prefetch_to_device`, but I cannot simply call `prefetch_to_device(train_ds.as_numpy_iterator())`, because `prefetch_to_device` requires the first dimension of each batch to be the number of devices. I only have one GPU, so it expects `as_numpy_iterator` to return batches of shape `(1, batch_size, ...)`.

I cannot find a way to make `as_numpy_iterator` return that extra dimension. Maybe I need a new way to construct the training dataset? Anyway, I'd appreciate any code snippet or example. Thank you.
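One way to get the extra dimension is to wrap the iterator in a generator that reshapes each batch. A minimal sketch, assuming batches are plain arrays and `shard_batches` is a hypothetical helper (it is not part of Flax; for dict-valued TFDS batches you would apply the reshape per leaf, e.g. with `jax.tree_util.tree_map`):

```python
import numpy as np

def shard_batches(it, num_devices=1):
    """Hypothetical wrapper: add a leading device axis so each batch has
    shape (num_devices, per_device_batch, ...), the layout that
    flax.jax_utils.prefetch_to_device expects."""
    for batch in it:
        # Split the batch dimension across devices.
        yield np.reshape(batch, (num_devices, -1) + batch.shape[1:])

# Stand-in for train_ds.as_numpy_iterator(): batches of shape (batch_size, 28, 28)
fake_iter = (np.zeros((32, 28, 28), dtype=np.float32) for _ in range(3))

first = next(shard_batches(fake_iter, num_devices=1))
print(first.shape)  # (1, 32, 28, 28)
```

With a real dataset you would then pass the wrapped iterator to the prefetcher, e.g. `flax.jax_utils.prefetch_to_device(shard_batches(train_ds.as_numpy_iterator()), size=2)`. This is only a sketch of the reshaping, not the library's own sharding helper.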