Hi @MKK_QML , welcome to the Forum!
This is probably due to the version of JAX you're using. With JAX 0.4.33 you likely won't see the issue; with JAX 0.5 or later you likely will.
Check out thread #8053 if you need some help on how to downgrade your JAX version.
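If it helps, here's a small sketch for checking which JAX version is installed in your current environment and whether it falls in the range that seems to trigger the issue (0.5 or later, per the note above). The helper names `parse_version` and `likely_affected` are just for illustration. Pinning `jaxlib` to the same version as `jax` in the suggested command is an assumption about your setup, so please adapt it if you install JAX differently (e.g. with GPU extras).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(v: str) -> tuple[int, ...]:
    """Keep only the leading numeric components, e.g. "0.5.0rc1" -> (0, 5, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            # Stop at suffixes such as "rc1" or "dev"
            break
    return tuple(parts)


def likely_affected(v: str) -> bool:
    """Heuristic from this thread: JAX 0.5 or later is likely to show the issue."""
    return parse_version(v) >= (0, 5)


if __name__ == "__main__":
    try:
        installed = version("jax")
    except PackageNotFoundError:
        print("JAX is not installed in this environment.")
    else:
        print(f"Installed JAX version: {installed}")
        if likely_affected(installed):
            # Assumed downgrade command; run it in the same environment
            print("This version may show the issue. You could try downgrading with:")
            print('  python -m pip install "jax==0.4.33" "jaxlib==0.4.33"')
        else:
            print("This version is unlikely to be the cause of the issue.")
```

Run it once before downgrading and once after, so you can confirm that the environment your code actually uses picked up the older version.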
Please let us know whether downgrading solves the issue, or if you run into any trouble downgrading or still see the issue afterwards!