Ability to Iterate Quickly in MJX #2441
Comments
@mpiseno We hear you and we feel the same pain! We have some pretty neat improvements to JIT speed and more granular JIT caching that we will share in a few weeks' time. I'll come back and leave a comment here when we're ready to share more.
P.S. We actually used to have a way of switching between JAX and numpy in Brax, called jumpy: https://github.com/google/brax/blob/main/brax/v1/jumpy.py You're welcome to play with the idea, but let me warn you that while it seems elegant at first glance, swapping numpy for JAX is full of really gnarly gotchas. The number of errors and edge cases ultimately made it not worth it in our eyes.
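As a rough illustration of the kind of gotcha described above, here is a minimal sketch of a function that tries to work with either array module. It is not taken from jumpy; `bump_first` is a hypothetical helper, and the two differences it shows (immutability and default dtype) are only examples of where the two libraries diverge.

```python
import numpy as np
import jax.numpy as jnp


def bump_first(xp, x):
    """Hypothetical helper: add 1 to x[0] for whichever array module `xp` is."""
    if xp is np:
        y = x.copy()
        y[0] += 1.0           # NumPy arrays are mutable
        return y
    return x.at[0].add(1.0)   # JAX arrays are immutable; use the .at[] syntax


np_out = bump_first(np, np.zeros(3))
jax_out = bump_first(jnp, jnp.zeros(3))

# In-place assignment on a JAX array raises TypeError.
try:
    jnp.zeros(3)[0] = 1.0
    mutation_failed = False
except TypeError:
    mutation_failed = True

# Default dtypes also differ: NumPy uses float64, while JAX uses float32
# unless the jax_enable_x64 option is turned on.
print(np_out.dtype, jax_out.dtype, mutation_failed)
```

Every such branch has to be written and tested twice, which is one plausible reading of why the approach was judged not worth the maintenance cost.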
@jloganolson That is more like logging in the training code, not something you would get from Mujoco / MJX. I would say if you want to log certain values you should add your own print statements to whatever training code you are using. Alternatively, Mujoco playground has an example of using RSL for training code, which is a library of RL algorithm implementations that does logging in the way you’re looking for.
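If the training code is jitted, a plain `print` only runs once at trace time and shows a tracer rather than a value. A minimal sketch of the usual workaround, `jax.debug.print`, is below; `loss_fn` and `training_step` are hypothetical stand-ins, not part of MJX or any training library.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """Hypothetical loss, used only to have something to log."""
    return jnp.mean((batch @ params) ** 2)


@jax.jit
def training_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Unlike print(), this runs on every call and shows the concrete value.
    jax.debug.print("loss = {loss}", loss=loss)
    return params - 0.1 * grads, loss


params = jnp.ones(3)
batch = jnp.eye(3)
for _ in range(3):
    params, loss = training_step(params, batch)
```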
Hi @jloganolson, regarding logging during training, we added a callback logger in the brax PPO trainer here. If you switch
Re: reducing JIT times via Warp, see https://github.com/google-deepmind/mujoco_warp
The feature, motivation and pitch
I am a graduate student using MJX for robot locomotion. With the recent release of Mujoco playground, I am interested in using MJX over IsaacGym, but a huge problem I run into is that I cannot effectively debug or iterate on small code changes.
It takes more than 5 minutes to trace/compile my jitted reset and step functions, which are functionally very similar to those found here. I have already tried all the typical advice to reduce JAX compilation time (e.g. using JAX control flow instead of native Python), but the main bottleneck is inside the `mjx.forward` and `mjx.step` functions.
#1273 touches on this issue, but is somewhat old and the advice of "use Mujoco for development and MJX for actual training" is not very satisfactory. One interesting point from this issue was:
Is this still on the roadmap?
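To see how the trace/compile time described above splits up, JAX's ahead-of-time API can time tracing, compilation, and execution separately. This is a minimal sketch: `step` is a small hypothetical stand-in, not an MJX function, and the timings it prints will be tiny compared with a real `mjx.step`.

```python
import time

import jax
import jax.numpy as jnp


def step(x):
    """Hypothetical stand-in for a jitted env step; swap in your own function."""
    def body(carry, _):
        return jnp.tanh(carry @ carry.T), None

    out, _ = jax.lax.scan(body, x, None, length=10)
    return out


x = jnp.ones((8, 8))

t0 = time.perf_counter()
lowered = jax.jit(step).lower(x)   # trace the Python function and lower it
t1 = time.perf_counter()
compiled = lowered.compile()       # XLA compilation
t2 = time.perf_counter()
result = compiled(x)               # run the compiled executable
result.block_until_ready()
t3 = time.perf_counter()

print(f"trace/lower: {t1 - t0:.3f}s  compile: {t2 - t1:.3f}s  run: {t3 - t2:.3f}s")
```

If most of the time lands in the compile phase, JAX's persistent compilation cache (the `jax_compilation_cache_dir` config option) may help across process restarts, though it will not help after the function itself has been edited.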
Feature Suggestion
How plausible is it to have a "debug mode" (i.e. using numpy, and doing things on cpu) under the hood for people to easily iterate on development? I am thinking something that allows users to still be able to use the mjx API (i.e. call `mjx.forward` and `mjx.step`) but it does not use jax arrays under the hood.
I am not too familiar with the internals of MJX, but considering jax.numpy and numpy share the same API, could we get most of the way there by just creating an alias for numpy, i.e. `import numpy as jp`?
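One thing worth knowing when weighing the alias idea: JAX already ships a context manager, `jax.disable_jit()`, that runs jitted functions eagerly, which gives a limited form of debug mode without changing array types. The sketch below uses a hypothetical `step` function, not MJX; whether eager execution is fast enough for a full `mjx.step` is not something it establishes.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    """Hypothetical stand-in for a jitted step function."""
    y = jnp.sin(x) * 2.0
    # JAX-only update syntax: NumPy arrays have no `.at` attribute,
    # so this line would fail if `jnp` were aliased to plain NumPy.
    return y.at[0].set(0.0)


x = jnp.arange(4.0)

with jax.disable_jit():
    # Inside this context, jitted functions run op by op with concrete
    # values, so print() and Python debuggers behave as usual.
    debug_out = step(x)

jit_out = step(x)  # normal compiled path
```

Both paths should return the same values, so the eager run can be used to step through logic before paying the compilation cost.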
Alternatives
No response
Additional context
No response