Using Catalyst with CUDA-Enabled JAX

Oh, one more thing. If you would really like to get CUDA working with Catalyst, could you try installing JAX via pip as described here? That is, after installing Catalyst, run:

pip install "jax[cuda12]"

This should add CUDA support to JAX without changing the installed jaxlib version.

(It would also be a good idea to remove the packages previously installed by conda/mamba, or simply try this in a clean environment 🙂)
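If you want to double-check that the jaxlib version really stayed the same and that JAX can now see your GPU, you could run a small script like the one below before and after the pip command and compare the output. This is only a sketch: it assumes Catalyst is installed under the distribution name `pennylane-catalyst`, and it assumes that `jax.devices("gpu")` raises a `RuntimeError` when no GPU backend is available. The helper names `installed_version` and `jax_gpu_devices` are just illustrative.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import List, Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def jax_gpu_devices() -> List:
    """Return the GPU devices JAX can see (empty list if none, or if JAX is missing)."""
    try:
        import jax
    except ImportError:
        return []
    try:
        return list(jax.devices("gpu"))
    except RuntimeError:
        # Assumption: JAX raises RuntimeError when no GPU backend is present,
        # e.g. when the CUDA plugin was not installed.
        return []


if __name__ == "__main__":
    # Compare these versions before and after `pip install "jax[cuda12]"`.
    for pkg in ("pennylane-catalyst", "jax", "jaxlib"):
        print(f"{pkg}: {installed_version(pkg)}")
    print("GPU devices seen by JAX:", jax_gpu_devices())
```

If the jaxlib line prints the same version both times and the last line lists at least one CUDA device, the setup should be in the state described above.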
