Is there a reference for generic best practices for data loading? I know this is a very generic question, but I've noticed that because Jax doesn't have pre-built dataloaders, I end up using some combination of Tensorflow and Jax, and I'm never sure whether what I'm doing is optimal. I think there are three questions in particular that would be good to have documented somewhere (I've read Jax's memory model documentation page, but still wasn't sure):
1. What is the recommended way to get batches out of a `tf.data` pipeline as Jax arrays? I've seen `ds.map(lambda x: jnp.array(x))`, but I believe people have had performance issues there. I've also seen it done using `dlpack` as a stepping stone (see https://stackoverflow.com/questions/69782818/turn-a-tf-data-dataset-to-a-jax-numpy-iterator). (First sketch after this list.)
2. How should device placement be handled? In Tensorflow there is `tf.device`, and in Jax, we have methods like `jax.device_put`. Is the correct pattern to load the dataset using Tensorflow onto the CPU, then use Jax on the GPU, e.g. something like `device_put_sharded`? (Second sketch below.)
3. Which devices should Jax actually see? In particular, often I'm on a cloud instance with a GPU, but `jax.devices()` only shows the GPU. `jax.devices('cpu')`, however, will show a CPU; despite this, any of the "sharding" calls will throw an error because there's only one device in `jax.devices()`, when I would expect there to be two. (Third sketch below.)
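On question 1, a minimal sketch of the route that skips `jnp.array` inside `ds.map` entirely: iterate the pipeline as NumPy and let the jitted step do the device transfer. The toy dataset, shapes, and `step` function are made up for illustration.

```python
import jax
import tensorflow as tf

# Toy host-side pipeline; the data and shapes are placeholders.
ds = tf.data.Dataset.from_tensor_slices(tf.ones((128, 8))).batch(32)

@jax.jit
def step(x):
    # x arrives as a plain NumPy array; calling the jitted function
    # copies it to the default device (the GPU, if one is visible).
    return (x * 2.0).sum()

# as_numpy_iterator() yields NumPy arrays, so no jnp.array call is
# needed inside ds.map at all.
for batch in ds.as_numpy_iterator():
    out = step(batch)
```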
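On question 2, a sketch of the "Tensorflow on the host, Jax on the accelerator" pattern. Hiding the GPU from Tensorflow is an assumption about the setup (it keeps `tf.data` from reserving GPU memory that Jax's allocator also wants), not something either library requires.

```python
import jax
import numpy as np
import tensorflow as tf

# Pin tf.data to the host by hiding the GPU from TensorFlow.
tf.config.set_visible_devices([], "GPU")

ds = tf.data.Dataset.from_tensor_slices(np.ones((128, 8), np.float32)).batch(32)

device = jax.devices()[0]  # on a single-GPU instance, the GPU
for batch in ds.as_numpy_iterator():
    x = jax.device_put(batch, device)  # explicit host -> device copy
    # ... compute on x with jitted Jax code ...
```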
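On question 3, a sketch of why the sharding helpers complain: `jax.devices()` lists only the default backend, and `jax.device_put_sharded` counts devices within that single backend, so the CPU backend's device never makes it "two devices". The batch shape here is arbitrary and assumed to divide evenly across devices.

```python
import jax
import numpy as np

print(jax.devices())        # default backend only, e.g. [cuda(id=0)]
print(jax.devices("cpu"))   # the host CPU lives on a separate backend

devices = jax.devices()
batch = np.ones((32, 8), np.float32)
shards = np.split(batch, len(devices))  # one shard per visible device

# With a single GPU this is a one-element "sharding"; asking for more
# shards than len(jax.devices()) is what raises the error.
sharded = jax.device_put_sharded(shards, devices)
print(sharded.shape)  # shards stacked along a new leading device axis
```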
Replies: 1 comment

Whilst I don't really have the answers to these, this thread is also relevant/interesting: