🚀 Feature
Remove the lax.cond evaluations when possible, which may include some transformations of the main algorithm structure during learning.
Motivation
One of the main reasons to use JAX is speed compared to Torch. Yet the use of jax.lax.cond makes it harder for the compiler to build efficient computation graphs, and can significantly slow down the algorithms.
I noticed this while implementing an end-to-end RL pipeline (from environment to training): by removing them, I could reach much higher performance on the same hardware.
Furthermore, it is often relatively simple to do so, or to find tricks to avoid them in the RL case.
Pitch
Based on the SAC algorithm, the following lax.cond can be removed:
- sac/SAC/SAC.update_critic (line 450). Since it runs at every timestep, it is a critical one to remove (a sketch of the pattern I mean is included below).
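
For reference, here is a minimal, self-contained sketch of the kind of pattern I am referring to. The names and the dummy update are purely illustrative, not the actual sbx code; I am only assuming the cond guards a delayed actor update:

```python
import jax
import jax.numpy as jnp

# Illustrative only: a lax.cond inside the per-step update decides whether the
# (delayed) actor update runs at this gradient step.
def update(params, step_idx, policy_delay=2):
    def actor_update(p):
        return p - 0.1  # placeholder for a real actor gradient step

    def skip(p):
        return p

    return jax.lax.cond(step_idx % policy_delay == 0, actor_update, skip, params)

params = jnp.zeros(())
for i in range(4):
    params = update(params, i)
```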
Alternatives
Since this lax.cond is there to perform the policy update only every n gradient updates, one could instead replace _train in sac/SAC/SAC.train (line 245) by a loop that does k gradient updates, then a policy update, then k gradient updates again, and so on, all inside a jitted function that would be more efficient (see the sketch below).
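
A minimal sketch of what I have in mind, assuming a policy delay of k critic updates per actor update. All names (train, critic_update, actor_update) and the dummy gradient steps are illustrative placeholders, not the actual sbx API:

```python
import jax
import jax.numpy as jnp

# Placeholders for the real critic/actor gradient steps (illustrative only).
def critic_update(params, batch):
    return params + 0.01 * batch.mean()

def actor_update(params, batch):
    return params - 0.01 * batch.mean()

@jax.jit
def train(params, batches):
    # batches has shape [n_blocks, k, batch_size]: each block is k critic
    # updates followed by one actor update, so no lax.cond is needed.
    def one_block(params, block):
        def critic_step(params, batch):
            return critic_update(params, batch), None

        params, _ = jax.lax.scan(critic_step, params, block)
        params = actor_update(params, block[-1])
        return params, None

    params, _ = jax.lax.scan(one_block, params, batches)
    return params

params = jnp.zeros(())
batches = jnp.ones((10, 2, 256))  # 10 blocks of k=2 critic updates each
params = train(params, batches)
```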
Another way to remove the lax.cond, though it would not benefit SBX right now, would be to use lax.select instead of lax.cond. Sometimes evaluating both branches is faster (although not cheaper) because of the resulting computation graph. It should also be noted that under vmap, lax.cond is converted to lax.select, so both branches are evaluated anyway (small illustration below).
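
For illustration, a tiny example of the select variant; the names and the dummy update are again placeholders, not the actual sbx code:

```python
import jax
import jax.numpy as jnp

# lax.select computes both branches and then picks one result, whereas
# lax.cond (outside of vmap) only executes the taken branch.
def step(params, do_actor_update):
    updated = params - 0.1  # placeholder for the actor-update result
    return jax.lax.select(do_actor_update, updated, params)

params = jnp.zeros(())
print(step(params, jnp.array(True)))   # -0.1
print(step(params, jnp.array(False)))  # 0.0
```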
Additional context
I am going to take some time today to put together a code snippet replacing this lax.cond and see whether it benefits the algorithm.
On my side, since the whole learning loop is jitted, the comparison is biased and I cannot provide reliable numbers for now.
This change could benefit TD3, TQC and CrossQ as well.
### Checklist
- [x] I have checked that there is no similar issue in the repo (required)