Using Catalyst with CUDA-Enabled JAX

Hello!
I am attempting to use Catalyst to write my code, but I am running into compatibility issues between Catalyst and the CUDA version of jaxlib. Since Catalyst does not appear to provide a conda package, I first used conda to install jax[cuda12] and the other dependencies, and then installed Catalyst with pip install pennylane-catalyst. However, this changed the installed JAX version, and I received the following warning:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed.

To address this, I tried reinstalling jaxlib using:

mamba install "jaxlib==0.4.23=cuda12*" -c conda-forge

After this, JAX imports normally. However, when I try to import Catalyst, I get the following warning:

/site-packages/catalyst/__init__.py:30: UserWarning: Catalyst detected a version mismatch for the installed 'jaxlib' package. Please make sure to install the exact version required by Catalyst to avoid undefined behavior.
Expected: 0.4.23 Found: 0.4.23.dev20240522
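If I understand the check correctly, it appears to be an exact version comparison, which the conda build's dev suffix breaks on its own. A minimal stdlib-only illustration (the version strings are the ones from the warning above):

```python
# Exact comparison fails even though the base versions agree.
expected = "0.4.23"
found = "0.4.23.dev20240522"  # the conda-forge CUDA build carries a .dev suffix

print(expected == found)  # False: the suffix alone causes the mismatch

# Stripping the suffix recovers the base version:
base = found.split(".dev")[0]
print(base == expected)  # True
```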

It seems that the CUDA build of jaxlib installed via conda/mamba always carries an extra suffix that prevents the versions from matching exactly. I am not sure how to avoid this, or whether it is expected behavior.

Ignoring the warning, I proceeded to run the code from this notebook. However, when I reached cell In[8]:

loss_fn(params, data, targets)

I encountered the following error:

----> 1 loss_fn(params, data, targets)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/jit.py:499, in QJIT.__call__(self, *args, **kwargs)
    496 if EvaluationContext.is_tracing():
    497     return self.user_function(*args, **kwargs)
--> 499 requires_promotion = self.jit_compile(args)
    501 # If we receive tracers as input, dispatch to the JAX integration.
    502 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/jit.py:573, in QJIT.jit_compile(self, args)
    570         self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
    572     self.mlir_module, self.mlir = self.generate_ir()
--> 573     self.compiled_function, self.qir = self.compile()
    575     self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)
    577 elif self.compiled_function is not cached_fn.compiled_fn:
    578     # Restore active state from cache.

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/jit.py:684, in QJIT.compile(self)
    682 func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
    683 shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
--> 684 compiled_fn = CompiledFunction(
    685     shared_object, func_name, restype, self.out_type, self.compile_options
    686 )
    688 return compiled_fn, llvm_ir

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/compiled_functions.py:141, in CompiledFunction.__init__(self, shared_object_file, func_name, restype, out_type, compile_options)
    137 @debug_logger_init
    138 def __init__(
    139     self, shared_object_file, func_name, restype, out_type, compile_options
    140 ):  # pylint: disable=too-many-arguments
--> 141     self.shared_object = SharedObjectManager(shared_object_file, func_name)
    142     self.compile_options = compile_options
    143     self.return_type_c_abi = None

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/compiled_functions.py:67, in SharedObjectManager.__init__(self, shared_object_file, func_name)
     65 self.teardown = None
     66 self.mem_transfer = None
---> 67 self.open()

File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/compiled_functions.py:71, in SharedObjectManager.open(self)
     69 def open(self):
     70     """Open the sharead object and load symbols."""
---> 71     self.shared_object = ctypes.CDLL(self.shared_object_file)
     72     self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols()

File ~/miniforge3/envs/jax/lib/python3.12/ctypes/__init__.py:379, in CDLL.__init__(self, name, mode, handle, use_errno, use_last_error, winmode)
    376 self._FuncPtr = _FuncPtr
    378 if handle is None:
--> 379     self._handle = _dlopen(self._name, mode)
    380 else:
    381     self._handle = handle

OSError: libgfortran-040039e1.so.5.0.0: cannot open shared object file: No such file or directory

Any help or guidance on how to resolve these issues would be greatly appreciated.

Thank you!

Hi @jracle ,

Thank you for your question.

Could you please post the output of qml.about()?

Note that if you’re on a Mac you’ll need to install Xcode (more details in the docs).

In any case I’d recommend creating a new virtual environment and running

pip install pennylane-catalyst

Let me know if this solves the issue!

Hi @jracle!

Thank you for giving Catalyst a spin. The easiest thing you can do to get up and running is to let pip manage the JAX installation required by Catalyst. So just try re-installing Catalyst (and making sure it re-installs jaxlib==0.4.23).

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed.

This warning is not a big issue and can be ignored for now, especially for the notebook you are trying to run. It just means that code you run with “pure” JAX won’t run on the GPU (it will fall back to the CPU), but all JAX code inside qjit-decorated functions is unaffected by this.

For some background, each version of Catalyst pins itself to a specific version of JAX in order to ensure compatibility between the packages (which is very important given the low level at which Catalyst operates). In this case, the versions are more or less the same (0.4.23), but I suspect the fact that jaxlib is coming from conda is causing issues with system libraries, given the error we see here:

OSError: libgfortran-040039e1.so.5.0.0: cannot open shared object file: No such file or directory

Conda often installs alternative versions of system libraries, which can cause conflicts if not managed properly.
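The failure mode itself is easy to reproduce: ctypes.CDLL raises exactly this kind of OSError whenever the dynamic loader cannot locate a shared object on its search path (the library name below is made up for illustration):

```python
import ctypes

# Loading a shared object that the dynamic loader cannot find raises
# OSError, just like the libgfortran failure in the traceback above.
try:
    ctypes.CDLL("libthis-does-not-exist.so.5")  # hypothetical library name
except OSError as exc:
    print(exc)  # e.g. "... cannot open shared object file: No such file or directory"
```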

Please let me know if my suggestion works for you while we look into the conda issue :slight_smile:

Oh one more thing. If you would really like to get CUDA working together with Catalyst, could you try installing JAX via pip as described here? That is, after installing Catalyst, running:

pip install jax[cuda12]

It should install CUDA support for JAX without changing the jaxlib package version.
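After re-installing, a quick stdlib-only way to double-check which versions pip resolved, without importing the packages themselves (a sketch; package names as used in this thread):

```python
# Print installed versions straight from package metadata.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("jax", "jaxlib", "pennylane-catalyst"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")
```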

( And it would be a good idea to remove the previous packages installed by conda/mamba, or just try this in a clean environment :slight_smile: )


Hi @David_Ittah , I would like to express my gratitude for the help. As suggested, I re-executed

pip install jax[cuda12]

and the dependencies worked correctly.

Additionally, I wanted to ask if Catalyst will support mamba/conda in the future. I always thought using mamba/conda for package management was a good choice, but it seems it may not always be suitable.

Thank you!

Oh great, I’m glad to hear it fixed your issue!

At the moment we have no plans of releasing a conda version of the Catalyst package, but we will keep an eye out to gauge whether this will be a big benefit to users.

Conda users (me included) would benefit from a Catalyst conda package, as PennyLane and several parts of the JAX ecosystem are already available there.

Thanks for your feedback @schance995 !

You should be able to pip install pennylane-catalyst.
You can see more advanced installation instructions here in the docs.

Is there a specific reason why you would prefer to have a conda package? Note that, as mentioned by David before, conda often installs alternative versions of system libraries, which can cause conflicts if not managed properly.

Let us know in case the pip install or the additional instructions don’t work for you!

The reason is that Catalyst now supports lightning-gpu, but there are no prebuilt lightning-gpu binaries for Linux on PyPI. There are prebuilt binaries for lightning-gpu on conda-forge, and it would be easier as an end user to have more PennyLane packages available there.

Hi @schance995,

You can install lightning-gpu with pip on Linux as follows; this way you can pip install everything. Does this solve your issue?

pip install custatevec_cu12
pip install pennylane-lightning-gpu

I just tried this and it worked! Not sure why pip didn’t give me the binaries when I tried last week.

It would still be good to have Catalyst inside conda-forge one day to ensure that all binaries are built with dynamically linked libraries from conda-forge. But I understand if that’s not the highest priority right now since Catalyst is still in heavy development.
