Hello!
I am trying to use Catalyst in my code, but I am running into compatibility issues between Catalyst and the CUDA build of jaxlib. Since Catalyst does not appear to be distributed through conda, I first used conda to install `jax[cuda12]` and the other dependencies, and then installed Catalyst with `pip install pennylane-catalyst`. However, I noticed that this changed the installed JAX version, and I received the following warning:
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed.
To address this, I tried reinstalling jaxlib with:
```
mamba install "jaxlib==0.4.23=cuda12*" -c conda-forge
```
After this, JAX imports normally. However, when I try to import Catalyst, I get the following warning:
/site-packages/catalyst/__init__.py:30: UserWarning: Catalyst detected a version mismatch for the installed 'jaxlib' package. Please make sure to install the exact version required by Catalyst to avoid undefined behavior.
Expected: 0.4.23 Found: 0.4.23.dev20240522
It seems that the CUDA build of jaxlib installed via conda/mamba always carries an extra suffix in its version string (here `.dev20240522`), which prevents it from matching the exact version Catalyst expects. I am unsure how to avoid this, or whether this is expected behavior. Ignoring the warning, I went on to run the code from this notebook. However, when I reached cell `input[8]`:
```python
loss_fn(params, data, targets)
```
I encountered the following error:
----> 1 loss_fn(params, data, targets)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/jit.py:499, in QJIT.__call__(self, *args, **kwargs)
496 if EvaluationContext.is_tracing():
497 return self.user_function(*args, **kwargs)
--> 499 requires_promotion = self.jit_compile(args)
501 # If we receive tracers as input, dispatch to the JAX integration.
502 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/jit.py:573, in QJIT.jit_compile(self, args)
570 self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
572 self.mlir_module, self.mlir = self.generate_ir()
--> 573 self.compiled_function, self.qir = self.compile()
575 self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)
577 elif self.compiled_function is not cached_fn.compiled_fn:
578 # Restore active state from cache.
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
140 @functools.wraps(fn)
141 def wrapper(*args, **kwargs):
142 if not InstrumentSession.active:
--> 143 return fn(*args, **kwargs)
145 with ResultReporter(stage_name, has_finegrained) as reporter:
146 fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/jit.py:684, in QJIT.compile(self)
682 func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
683 shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
--> 684 compiled_fn = CompiledFunction(
685 shared_object, func_name, restype, self.out_type, self.compile_options
686 )
688 return compiled_fn, llvm_ir
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
63 @wraps(func)
64 def wrapper_exit(*args, **kwargs):
---> 65 output = func(*args, **kwargs)
66 if lgr.isEnabledFor(log_level): # pragma: no cover
67 f_string = _get_bound_signature(*args, **kwargs)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/compiled_functions.py:141, in CompiledFunction.__init__(self, shared_object_file, func_name, restype, out_type, compile_options)
137 @debug_logger_init
138 def __init__(
139 self, shared_object_file, func_name, restype, out_type, compile_options
140 ): # pylint: disable=too-many-arguments
--> 141 self.shared_object = SharedObjectManager(shared_object_file, func_name)
142 self.compile_options = compile_options
143 self.return_type_c_abi = None
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
63 @wraps(func)
64 def wrapper_exit(*args, **kwargs):
---> 65 output = func(*args, **kwargs)
66 if lgr.isEnabledFor(log_level): # pragma: no cover
67 f_string = _get_bound_signature(*args, **kwargs)
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/compiled_functions.py:67, in SharedObjectManager.__init__(self, shared_object_file, func_name)
65 self.teardown = None
66 self.mem_transfer = None
---> 67 self.open()
File ~/miniforge3/envs/jax/lib/python3.12/site-packages/catalyst/compiled_functions.py:71, in SharedObjectManager.open(self)
69 def open(self):
70 """Open the sharead object and load symbols."""
---> 71 self.shared_object = ctypes.CDLL(self.shared_object_file)
72 self.function, self.setup, self.teardown, self.mem_transfer = self.load_symbols()
File ~/miniforge3/envs/jax/lib/python3.12/ctypes/__init__.py:379, in CDLL.__init__(self, name, mode, handle, use_errno, use_last_error, winmode)
376 self._FuncPtr = _FuncPtr
378 if handle is None:
--> 379 self._handle = _dlopen(self._name, mode)
380 else:
381 self._handle = handle
OSError: libgfortran-040039e1.so.5.0.0: cannot open shared object file: No such file or directory
Any help or guidance on how to resolve these issues would be greatly appreciated.
Thank you!