Compiler crash while computing circuit gradient (qjit and vmap) with lightning.qubit and lightning.kokkos

Hello there,

I have an issue computing the qjit-ed and vmap-ed gradient of a circuit. The compiler seems to be crashing for a reason I do not understand. A minimal example and some further information are below. Does anybody have an idea how to resolve this issue?

########################################
Name: PennyLane
Version: 0.41.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /opt/conda/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos

Platform info: Linux-5.15.0-142-generic-x86_64-with-glibc2.35
Python version: 3.11.10
Numpy version: 2.3.1
Scipy version: 1.16.0
Installed devices:

  • lightning.gpu (PennyLane_Lightning_GPU-0.41.1)
  • default.clifford (PennyLane-0.41.1)
  • default.gaussian (PennyLane-0.41.1)
  • default.mixed (PennyLane-0.41.1)
  • default.qubit (PennyLane-0.41.1)
  • default.qutrit (PennyLane-0.41.1)
  • default.qutrit.mixed (PennyLane-0.41.1)
  • default.tensor (PennyLane-0.41.1)
  • null.qubit (PennyLane-0.41.1)
  • reference.qubit (PennyLane-0.41.1)
  • lightning.qubit (PennyLane_Lightning-0.41.1)
  • nvidia.custatevec (PennyLane-Catalyst-0.11.0)
  • nvidia.cutensornet (PennyLane-Catalyst-0.11.0)
  • oqc.cloud (PennyLane-Catalyst-0.11.0)
  • softwareq.qpp (PennyLane-Catalyst-0.11.0)
  • lightning.kokkos (PennyLane_Lightning_Kokkos-0.41.1)

#######################################

Minimal example:

def test_catalyst_vmap_qjit():
    """Test Catalyst vmap with qjit (realistic usage)"""
    print("\nTesting Catalyst vmap + qjit...")
    try:
        import catalyst
        import pennylane as qml
        import jax.numpy as jnp

        dev = qml.device("lightning.qubit", wires=2)

        @qml.qnode(dev, interface='jax')
        def circuit(x, w):
            qml.RX(x, wires=0)
            qml.RY(w, wires=1)
            return qml.expval(qml.PauliZ(0))

        # Test vmap with qjit (the way it's actually used)
        batch_circuit = catalyst.vmap(circuit, in_axes=(0, None))
        compiled_circuit = qml.qjit(batch_circuit)

        batch_input = jnp.array([0.1, 0.2, 0.3])
        w = jnp.array([0.5])
        results = compiled_circuit(batch_input, w)
        print(f"✓ Catalyst vmap + qjit working: batch results = {results}")

        # This is the step that triggers the compiler crash.
        grad_compiled = catalyst.grad(compiled_circuit)
        results = grad_compiled(batch_input, w)
        print(f"✓ Catalyst gradient vmap + qjit working: batch results = {results}")
        return True
    except Exception as e:
        print(f"✗ Catalyst vmap + qjit failed: {e}")
        return False


if __name__ == "__main__":
    test_catalyst_vmap_qjit()

############################################

Output:

Testing Catalyst vmap + qjit…
2025-07-17 09:07:09.765846: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver’s CUDA version is 12.4 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
✓ Catalyst vmap + qjit working: batch results = [0.99500417 0.98006658 0.95533649]
✗ Catalyst vmap + qjit failed: catalyst failed with error code -6: catalyst: /__w/catalyst/catalyst/mlir/llvm-project/mlir/lib/Analysis/Liveness.cpp:45: {anonymous}::BlockInfoBuilder::BlockInfoBuilder(mlir::Block*)::<lambda(mlir::Value)>: Assertion `ownerBlock && "Use leaves the current parent region"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: /opt/conda/bin/catalyst -o /tmp/deriv_vmap.circuitzredee5n/deriv_vmap.circuit.ll --module-name deriv_vmap.circuit --workspace /tmp/deriv_vmap.circuitzredee5n -verify-each=false --catalyst-pipeline EnforceRuntimeInvariantsPass(split-multiple-tapes;builtin.module(apply-transform-sequence);inline-nested-module),HLOLoweringPass(canonicalize;func.func(chlo-legalize-to-hlo);stablehlo-legalize-to-hlo;func.func(mhlo-legalize-control-flow);func.func(hlo-legalize-to-linalg);func.func(mhlo-legalize-to-std);func.func(hlo-legalize-sort);convert-to-signless;canonicalize;scatter-lowering;hlo-custom-call-lowering;cse;func.func(linalg-detensorize{aggressive-mode});detensorize-scf;canonicalize),QuantumCompilationPass(annotate-function;lower-mitigation;lower-gradients;adjoint-lowering),BufferizationPass(one-shot-bufferize{dialect-filter=memref};inline;gradient-preprocess;gradient-bufferize;scf-bufferize;convert-tensor-to-linalg;convert-elementwise-to-linalg;arith-bufferize;empty-tensor-to-alloc-tensor;func.func(bufferization-bufferize);func.func(tensor-bufferize);catalyst-bufferize;func.func(linalg-bufferize);func.func(tensor-bufferize);quantum-bufferize;func-bufferize;func.func(finalizing-bufferize);canonicalize;gradient-postprocess;func.func(buffer-hoisting);func.func(buffer-loop-hoisting);func.func(buffer-deallocation);convert-arraylist-to-memref;convert-bufferization-to-memref;canonicalize;cp-global-memref),MLIRToLLVMDialect(expand-realloc;convert-gradient-to-llvm;memrefcpy-to-linalgcpy;func.func(convert-linalg-to-loops);convert-scf-to-cf;expand-strided-metadata;lower-affine;arith-expand;convert-complex-to-standard;convert-complex-to-llvm;convert-math-to-llvm;convert-math-to-libm;convert-arith-to-llvm;memref-to-llvm-tbaa;finalize-memref-to-llvm{use-generic-functions};convert-index-to-llvm;convert-catalyst-to-llvm;convert-quantum-to-llvm;emit-catalyst-py-interface;canonicalize;reconcile-unrealized-casts;gep-inbounds;register-inactive-callback), /tmp/deriv_vmap.circuitzredee5n/tmpk2xja2xe.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 catalyst 0x00000000098d62ab llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1 catalyst 0x00000000098d36cb llvm::sys::RunSignalHandlers() + 43
2 catalyst 0x00000000098d37f5
3 libc.so.6 0x00007fa8625b1520
4 libc.so.6 0x00007fa8626059fc pthread_kill + 300
5 libc.so.6 0x00007fa8625b1476 raise + 22
6 libc.so.6 0x00007fa8625977f3 abort + 211
7 libc.so.6 0x00007fa86259771b
8 libc.so.6 0x00007fa8625a8e96
9 catalyst 0x0000000009562997
10 catalyst 0x0000000009561f0a
11 catalyst 0x000000000956d2ee mlir::Liveness::build() + 174
12 catalyst 0x000000000513f0b9
13 catalyst 0x000000000952c5ee mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1038
14 catalyst 0x000000000952caa8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
15 catalyst 0x000000000952d25d mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) + 461
16 catalyst 0x000000000952c42a mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 586
17 catalyst 0x000000000952caa8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
18 catalyst 0x000000000952da05 mlir::PassManager::run(mlir::Operation*) + 1205
19 catalyst 0x00000000036012c3 runPipeline(mlir::PassManager&, catalyst::driver::CompilerOptions const&, catalyst::driver::CompilerOutput&, catalyst::driver::Pipeline&, bool, mlir::ModuleOp) + 259
20 catalyst 0x000000000360172d runLowering(catalyst::driver::CompilerOptions const&, mlir::MLIRContext*, mlir::ModuleOp, catalyst::driver::CompilerOutput&, mlir::TimingScope&) + 445
21 catalyst 0x0000000003603fe4 QuantumDriverMain(catalyst::driver::CompilerOptions const&, catalyst::driver::CompilerOutput&, mlir::DialectRegistry&) + 6676
22 catalyst 0x0000000003608a0b QuantumDriverMainFromCL(int, char**) + 10395
23 libc.so.6 0x00007fa862598d90
24 libc.so.6 0x00007fa862598e40 __libc_start_main + 128
25 catalyst 0x00000000035e012e _start + 46

Thank you very much @minn-bj for reporting this. Please allow some time for our team to look into it 🙂

Hey David,
I already found the issue. While using catalyst.grad is totally fine, using catalyst.qjit leads to the error; using qml.qjit instead works fine for me. A slightly modified version of the code can be found below.
Thanks anyway 🙂

Modified Code:

"""Test Catalyst vmap with qjit (realistic usage)"""
print("\nTesting Catalyst vmap + qjit...")
try:
     import catalyst
     import pennylane as qml
     import jax.numpy as jnp
    
    dev = qml.device("lightning.qubit", wires=2)
    
    @qml.qnode(dev, interface='jax')
    def circuit(x,w):
        qml.RX(x, wires=0)
        qml.RY(w, wires=1)

        return qml.expval(qml.PauliZ(0))
    
    # Test vmap with qjit (the way it's actually used)
    batch_circuit = catalyst.vmap(circuit, in_axes=(0, None))
    compiled_circuit = qml.qjit(batch_circuit)
    
    batch_input = jnp.array([0.1, 0.2, 0.3])
    w=jnp.array([0.5])
    results = compiled_circuit(batch_input,w)
    print(f"✓ Catalyst vmap + qjit working: batch results = {results}")
    
    def loss(w,x):
        return jnp.average((compiled_circuit(x,w)-x)**2)
    grad_compiled = qml.qjit(catalyst.grad(loss, method="fd"))
    results = grad_compiled(w, batch_input)
    print(f"✓ Catalyst gradient vmap + qjit  working: batch results = {results}")
    return True
except Exception as e:
    print(f"✗ Catalyst vmap + qjit failed: {e}")
    return False

Fantastic, I’m glad you found a solution!

Note that qml.qjit and catalyst.qjit run the same code. I believe the reason your example works now is that you are using catalyst.grad inside your qjitted function, which is the most common use case. Using it outside of the compiled program dispatches to the Catalyst-JAX integration, which uses a different AD function under the hood.

So I believe you have in fact discovered a bug in the Catalyst-JAX integration, which we’ll look into 🙂
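
For reference, here is a minimal sketch of the two placements, using a simplified single-parameter circuit rather than your exact code:

import catalyst
import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# Common pattern: catalyst.grad *inside* the compiled program,
# so differentiation is handled by Catalyst's native AD.
@qml.qjit
def grad_inside(x):
    return catalyst.grad(circuit)(x)

# Pattern from the original report: grad applied *outside* the compiled
# program, which goes through the Catalyst-JAX integration (this is the
# path that crashed when combined with vmap above).
compiled = qml.qjit(circuit)
grad_outside = catalyst.grad(compiled)

print(grad_inside(0.3))
print(grad_outside(0.3))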

EDIT: Just noticed that the original example also had the issue of trying to compute the gradient of a non-scalar function.
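
For example, the batched output can be reduced to a scalar before differentiating; a minimal sketch reusing compiled_circuit, batch_input, w, and the imports from the modified code above:

# compiled_circuit(x, w) returns one expectation value per batch element;
# catalyst.grad needs a scalar-valued function, so reduce the output first.
def scalar_objective(x, w):
    return jnp.sum(compiled_circuit(x, w))

grad_fn = qml.qjit(catalyst.grad(scalar_objective, method="fd"))
print(grad_fn(batch_input, w))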

EDIT2: I looked at the working circuit again; my initial assessment was wrong. It is not using catalyst.grad inside the qjitted function that fixed the issue, but the use of method="fd", which is a much simpler and more robust implementation that works on almost any program. The downside is that it can be less accurate, numerically less stable, and possibly more expensive depending on the number of parameters versus results.
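
If you ever want to sanity-check an fd gradient, you can compare it against a manual forward difference; a quick sketch using the loss, w, and batch_input from your modified code (the step size h is an arbitrary choice here):

h = 1e-6  # finite-difference step size, chosen arbitrarily for this check
# Forward difference: dL/dw ≈ (loss(w + h) - loss(w)) / h
manual = (loss(w + h, batch_input) - loss(w, batch_input)) / h
print(manual, grad_compiled(w, batch_input))

Note that catalyst.grad also accepts an h keyword argument to control the step size used by method="fd".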


Thank you for the additional information. Are there any plans to extend the method argument so that other differentiation methods can be used?

Do you have any specific ones in mind?

Currently the other option in the catalyst.grad function is "auto", which uses backpropagation for classical code and then whatever is specified in the QNode under diff_method for quantum code (default should be "adjoint" in Catalyst, but "parameter-shift" is also supported).
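
As a sketch of what selecting these looks like (again a simplified one-parameter circuit, not the one from the thread):

import catalyst
import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

# The QNode's diff_method picks the quantum gradient rule used by "auto";
# "adjoint" is the Catalyst default, "parameter-shift" is also supported.
@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

@qml.qjit
def grad_auto(x):
    # "auto": backpropagation for classical code, diff_method for quantum code
    return catalyst.grad(circuit, method="auto")(x)

print(grad_auto(0.4))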