Calling JAXAgent train gets stuck if using larger image sizes (inside Ninjax) #85
Sometimes the trace can take a while on old GPUs; I've waited around 10 minutes on a TitanX workstation before. You can try making the CNN smaller to see if that speeds up compilation. You can also try increasing the resolution incrementally and checking whether the trace time grows with it.
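The incremental experiment suggested above can be sketched as follows. This is a minimal sketch, assuming a toy strided CNN as a stand-in for the Dreamer image encoder; `init_cnn`, `encode`, `loss_fn` and `time_trace_and_compile` are hypothetical names, not DreamerV3 or Ninjax APIs. It uses `jax.jit(...).lower(...)` to time tracing plus lowering separately from `.compile()`, so you can see which phase grows with resolution.

```python
import time

import jax
import jax.numpy as jnp


def init_cnn(key, depth=32, stages=4, channels=3):
    """Hypothetical stand-in for an image encoder: `stages` stride-2 convs."""
    params = []
    in_ch = channels
    for i in range(stages):
        key, sub = jax.random.split(key)
        out_ch = depth * 2 ** i
        # 4x4 kernel in HWIO layout (height, width, in channels, out channels).
        params.append(jax.random.normal(sub, (4, 4, in_ch, out_ch)) * 0.02)
        in_ch = out_ch
    return params


def encode(params, images):
    x = images
    for w in params:
        # Each stride-2 conv halves the spatial resolution.
        x = jax.lax.conv_general_dilated(
            x, w, window_strides=(2, 2), padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"))
        x = jax.nn.silu(x)
    return x.reshape(x.shape[0], -1)


def loss_fn(params, images):
    return jnp.mean(encode(params, images) ** 2)


def time_trace_and_compile(resolution, batch=2):
    params = init_cnn(jax.random.PRNGKey(0))
    images = jnp.zeros((batch, resolution, resolution, 3))
    train_step = jax.jit(jax.grad(loss_fn))
    t0 = time.perf_counter()
    lowered = train_step.lower(params, images)  # Python tracing + lowering
    t1 = time.perf_counter()
    lowered.compile()  # XLA compilation only; nothing is executed
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


results = {}
for res in (64, 128, 256):
    trace_s, compile_s = time_trace_and_compile(res)
    results[res] = (trace_s, compile_s)
    print(f"{res}x{res}: trace+lower {trace_s:.2f}s, compile {compile_s:.2f}s")
```

If the trace+lower time stays flat while the compile time explodes, the bottleneck is in XLA rather than in the Python-side model code.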
Thanks. I am not sure that time or compute power is really the problem: even after 24 hours, it did not finish tracing on an A100. But I will test how the tracing time increases with image resolution and report my findings here.
I worked a little more on this and found out that the … Furthermore, I tracked the problem down a bit further, and it seems to arise in the … (lines 60 to 101 in 8fa35f8).
Hi Danijar,
I am currently trying to use higher image resolutions, such as 256x256, with Dreamer. When I simply change the resolution (e.g. for the DM Control Suite), JAX is no longer able to trace/compile the training function.
Instead of raising an error, the program seems to hang at (or right after) the point where it tries to trace the training function with JAX.
I have tested this on a V100 and an A100, with the same result on both. With smaller resolutions (e.g. 128x128 or 64x64), everything works as expected.
I tried to debug this, but I have not been able to track it down inside Ninjax or JAX.
Thanks a lot for your help!
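One way to narrow down where such a hang happens is sketched below, on a hypothetical toy `step` function rather than the real Dreamer training step: enable JAX's compile logging and use `jax.make_jaxpr` to check whether tracing alone finishes and how large the traced program is. For a fixed architecture, the number of equations should not depend on the image resolution. If it does grow with resolution in the real model, some Python-level loop is presumably depending on the input shape.

```python
import jax
import jax.numpy as jnp

# Log every XLA compilation, so a hang can be attributed to tracing vs compiling.
jax.config.update("jax_log_compiles", True)


def step(w, images):
    # Hypothetical toy step: one stride-2 conv followed by a mean-squared loss.
    y = jax.lax.conv_general_dilated(
        images, w, window_strides=(2, 2), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
    return jnp.mean(y ** 2)


w = jnp.zeros((4, 4, 3, 32))
sizes = {}
for res in (64, 256):
    images = jnp.zeros((2, res, res, 3))
    # make_jaxpr only traces; it does not compile or run the gradient.
    closed = jax.make_jaxpr(jax.grad(step))(w, images)
    sizes[res] = len(closed.jaxpr.eqns)
    print(f"{res}x{res}: {sizes[res]} top-level jaxpr equations")
```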