Catalyst compilation failure: operand #1 does not dominate this use

Hi! I'm trying out the "How to optimize a QML model using JAX and Optax" tutorial from PennyLane Demos.

This code snippet is the same as the tutorial, except that I use lightning.qubit and Catalyst instead of default.qubit. I'm also using PennyLane 0.35.0, Catalyst 0.5.0, and an NVIDIA GPU. Finally, jit_circuit uses catalyst.vmap to enable broadcasting for lightning.qubit (otherwise a list comprehension is required):

import pennylane as qml
import jax
from jax import numpy as jnp
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev_name = "lightning.qubit"
dev = qml.device(dev_name, wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    for i in range(n_wires):
        qml.RY(data[i], wires=i)

    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# try broadcasting: map over the columns of data (axis 1) and share weights
jit_circuit = catalyst.qjit(catalyst.vmap(circuit, in_axes=(1, None)))

def my_model(data, weights, bias):
    # works with default.qubit
    if dev_name == "default.qubit":
        return circuit(data, weights) + bias

    # works with lightning.qubit, not broadcasted
    # return jnp.array([circuit(jnp.array(d), weights) for d in data.T])

    # works for loss_fn (forward pass), but fails when computing jax.grad
    return jit_circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss


weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))
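A quick illustration of what in_axes=(1, None) does in jit_circuit: data has shape (5, 4), so mapping over axis 1 treats each of the 4 columns as one sample while weights are passed through unchanged, giving 4 predictions that line up with targets. The sketch below shows this with plain jax.vmap in place of catalyst.vmap; fake_circuit is a hypothetical classical stand-in for the QNode, so it runs without any quantum device.

```python
import jax
import jax.numpy as jnp

n_wires = 5
# Same input data as the tutorial: shape (5, 4), i.e. 4 samples stored as columns.
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
weights = jnp.ones([n_wires, 3])


def fake_circuit(x, w):
    """Hypothetical classical stand-in for the QNode (not PennyLane code).

    x: one sample of shape (5,); w: weights of shape (5, 3); returns a scalar.
    """
    return jnp.sum(jnp.cos(x + w[:, 0]) * jnp.sin(w[:, 1]))


# in_axes=(1, None): map over axis 1 (columns) of data, reuse weights unchanged.
batched = jax.vmap(fake_circuit, in_axes=(1, None))
predictions = batched(data, weights)  # one prediction per column

# The unbatched equivalent: a Python loop over the columns of data.
looped = jnp.array([fake_circuit(d, weights) for d in data.T])
```

Both versions should agree; the vmap form lets JAX vectorize the work instead of running a Python loop.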

Computing loss_fn works, but jax.grad fails to compile. The primary error seems to be "operand #1 does not dominate this use", raised while the compiler lowers the MLIR module in the GradientLoweringPass. I'm not sure how to resolve this error. Could someone please help debug it? Thanks!

2024-03-06 11:57:10.540238: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
0.2923261237897222
Traceback (most recent call last):
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 459, in run_from_ir
    compiler_output = run_compiler_driver(
                      ^^^^^^^^^^^^^^^^^^^^
RuntimeError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
  func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
      %21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 54, in <module>
    print(loss_fn(params, data, targets))
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 45, in loss_fn
    predictions = my_model(data, params["weights"], params["bias"])
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 41, in my_model
    return jit_circuit(data, weights) + bias
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 114, in __call__
    return self.jaxed_function(*args, **kwargs)
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 419, in __call__
    return self.jaxed_function(*args, **kwargs)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
  func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
      %21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 56, in <module>
    print(jax.grad(loss_fn)(params, data, targets))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 395, in compute_jvp
    derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 375, in get_derivative_qjit
    self.derivative_functions[argnum_key] = QJIT(
                                            ^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 101, in __init__
    self.aot_compile()
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 135, in aot_compile
    self.compiled_function, self.qir = self.compile()
                                       ^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 271, in compile
    shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 504, in run
    return self.run_from_ir(
           ^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 469, in run_from_ir
    raise CompileError(*e.args) from e
catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
  func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
      %21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Hey @schance995!

I can't replicate the behaviour you're getting. In any case, you should try using @qml.qjit in place of @jax.jit! It's worth looking through the Catalyst quickstart: Quick Start (Catalyst 0.5.0 documentation). Let me know if that helps!

Hi @isaacdevlugt, what output did you get when running my previous code?

I also tried using qml.qjit with some new code. It's a little long, so I put it in a GitHub gist: "Pennylane Catalyst debugging".

Basically, I define the following for each of JAX, Catalyst, and the qml wrappers for Catalyst:

  1. A quantum circuit
  2. An evaluation function that wraps the circuit with the framework’s vmap
  3. A loss function and its gradient, also built with each framework's functions (jax.jit, catalyst.qjit, qml.qjit, and the corresponding grad functions).
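For readers without access to the gist, the three pieces listed above might look roughly like this in the pure-JAX case. This is a sketch, not the gist's actual code: toy_circuit and evaluate are hypothetical names, toy_circuit is a classical stand-in for a QNode, and the shapes are arbitrary. The Catalyst variants would swap in catalyst.qjit or qml.qjit, catalyst.vmap, and the corresponding grad functions.

```python
import jax
import jax.numpy as jnp


# 1. "Circuit": hypothetical classical stand-in for a QNode, one scalar per sample.
def toy_circuit(x, weights):
    return jnp.sum(jnp.cos(weights @ x))


# 2. Evaluation function: batch the circuit over the rows of inputs, share weights.
def evaluate(inputs, weights):
    return jax.vmap(toy_circuit, in_axes=(0, None))(inputs, weights)


# 3. Loss function and its gradient with respect to the weights, both compiled.
@jax.jit
def loss_fn(weights, inputs, targets):
    predictions = evaluate(inputs, weights)
    return jnp.mean((targets - predictions) ** 2)


grad_loss_fn = jax.jit(jax.grad(loss_fn))

weights = jnp.array([[0.1, 0.2], [0.3, 0.4]])
inputs = jnp.arange(8.0).reshape(4, 2) / 8.0
targets = jnp.array([0.5, -0.5, 0.25, 0.0])
print(loss_fn(weights, inputs, targets))
print(grad_loss_fn(weights, inputs, targets))
```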

Here are the errors I get now:

  1. When no differentiable arguments are used (qml.RY(inputs[i], wires=i) commented out in the circuit), errors occur for the gradient functions compiled with catalyst.qjit and qml.qjit:
compiled_grad_loss_fn_cat
Compilation failed:
grad_loss_fn_cat:6:3: error: operand #1 does not dominate this use
  func.func private @loss_fn_cat(%arg0: tensor<2x2xf64>, %arg1: tensor<2x2xf64>, %arg2: tensor<2x2xf64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
grad_loss_fn_cat:6:3: note: see current operation: %10 = "func.call"(%arg0, %28) <{callee = @circuit.pcount}> : (tensor<2x2xf64>, tensor<2xf64>) -> index
grad_loss_fn_cat:39:13: note: operand defined here (op in the same block)
      %36 = func.call @_take(%arg1, %35) : (tensor<2x2xf64>, tensor<i64>) -> tensor<2xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module

compiled_grad_loss_fn_qml
Compilation failed:
grad_loss_fn_qml:6:3: error: operand #1 does not dominate this use
  func.func private @loss_fn_qml(%arg0: tensor<2x2xf64>, %arg1: tensor<2x2xf64>, %arg2: tensor<2x2xf64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
grad_loss_fn_qml:6:3: note: see current operation: %10 = "func.call"(%arg0, %28) <{callee = @circuit.pcount}> : (tensor<2x2xf64>, tensor<2xf64>) -> index
grad_loss_fn_qml:39:13: note: operand defined here (op in the same block)
      %36 = func.call @_take(%arg1, %35) : (tensor<2x2xf64>, tensor<i64>) -> tensor<2xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module
  2. When differentiable arguments are used (qml.RY(inputs[i], wires=i) included in the circuit), all circuits using only JAX give this error:
INTERNAL: Generated function failed: CpuCallback error: TypeError: RY(): incompatible function arguments. The following argument types are supported:
    1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None
    2. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: List[bool], arg2: List[int], arg3: bool, arg4: List[float]) -> None

Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7f8476e219b0>, [0], False, [array([2., 2.])]

These look like limitations on using Catalyst for QML workloads that require batching. Are these known bugs, or is something wrong with my code example?

Hi @schance995, I can replicate your errors.

I can see that some of the errors stop showing up when you use the development version of PennyLane. You can try it by running pip install git+https://github.com/PennyLaneAI/pennylane.git#egg=pennylane.

This will become the stable version next week, at which point you can simply run pip install pennylane --upgrade.

Some errors still show up, though, so we'll have to dig deeper to understand what's happening here.

Anyway, I hope this helps! Let me know if you have any additional questions or findings.


@CatalinaAlbornoz, with PennyLane 0.36 and Catalyst 0.6.0 the CpuCallback error has disappeared.

However, the "Failed to lower MLIR module" error still persists.

I have also found that running the gist with Python 3.12 gives errors like these, probably related to keyword argument changes:

Exception ignored in: <finalize object at 0x7f5ee8150120; dead>
Traceback (most recent call last):
  File "/home/schan/miniforge3/envs/py312/lib/python3.12/weakref.py", line 590, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/py312/lib/python3.12/site-packages/catalyst/utils/filesystem.py", line 71, in _cleanup
    tempfile.TemporaryDirectory._rmtree(name, **kwargs)
TypeError: TemporaryDirectory._rmtree() got an unexpected keyword argument 'delete'

Hi @schance995, I'm sorry you're experiencing this issue. There is a bug in the compiler's differentiation system that we're working hard to fix. Hopefully we can share more information or a bugfix release with you soon!


Thanks for reporting the Python 3.12 error (TemporaryDirectory._rmtree() got an unexpected keyword argument 'delete')! We implemented a fix in PR #729, which will be available in the next (bugfix) release.

Hi @David_Ittah,

I encountered the same error when computing gradients with PennyLane 0.39.0 and Catalyst 0.9.0.

I found that applying catalyst.vmap() outside qml.qjit() (i.e., vmapping the already-compiled function, as in the workaround below) no longer raises this error. However, this probably reduces the level of optimization, since catalyst.vmap() then dispatches internally to jax.vmap().
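As a rough analogy in plain JAX (jax.jit and jax.vmap standing in for qml.qjit and catalyst.vmap, with a hypothetical stand-in function f instead of the QNode), the two decorator orderings look like this. In plain JAX both orderings give the same values and gradients; in the report here, only the vmap-inside-qjit ordering triggers the Catalyst compilation error.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    """Hypothetical stand-in for the QNode: x and w both have shape (3,)."""
    return jnp.cos(jnp.dot(x, w))


# Ordering 1: vmap inside the compiled function (like @qjit above @catalyst.vmap).
inner = jax.jit(jax.vmap(f, in_axes=(0, None)))
# Ordering 2: vmap around an already-compiled function (like @catalyst.vmap above @qjit).
outer = jax.vmap(jax.jit(f), in_axes=(0, None))

x = jnp.linspace(0.0, 1.0, 15).reshape(5, 3)
w = jnp.array([0.3, -0.2, 0.7])

# Gradient of the summed outputs with respect to the batched inputs, as in the repro.
g_inner = jax.grad(lambda x, w: jnp.sum(inner(x, w)))(x, w)
g_outer = jax.grad(lambda x, w: jnp.sum(outer(x, w)))(x, w)
```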

Code to reproduce:

import catalyst
import jax.numpy as jnp
import jax.random
import pennylane as qml
from pennylane import qjit, qnode

N_INPUT = 4

device = qml.device("lightning.qubit", wires=N_INPUT)


@qjit
@catalyst.vmap(in_axes=(0, None), out_axes=0)
@qnode(device, interface="jax")
def circuit(inputs, weights):
    qml.AngleEmbedding(features=inputs, wires=range(N_INPUT), rotation="X")
    for i in range(1, N_INPUT):
        qml.CRX(weights[i - 1], wires=[i, 0])
    return qml.expval(qml.PauliZ(wires=0))


def loss_fn(x, weights, circuit):
    return jnp.sum(circuit(x, weights))


if __name__ == '__main__':
    weights = jax.random.uniform(jax.random.key(0), shape=(N_INPUT - 1,), minval=0, maxval=jnp.pi)
    inputs = jax.random.uniform(jax.random.key(0), shape=(5, N_INPUT - 1))
    val = loss_fn(inputs, weights, circuit)
    print(val.shape)
    grad_fn = catalyst.grad(loss_fn)
    grad = grad_fn(inputs, weights, circuit)
    print(grad.shape)

Possible workaround (apply catalyst.vmap outside qjit):

...
@catalyst.vmap(in_axes=(0, None), out_axes=0)
@qjit
...
Error

Compilation failed:
deriv_vmap.circuit:12:3: error: operand #0 does not dominate this use
  func.func private @vmap.circuit(%arg0: tensor<5x3xf64>, %arg1: tensor<3xf64>) -> tensor<5xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_vmap.circuit:12:3: note: see current operation: %8 = "func.call"(%13, %arg1) <{callee = @circuit_0.pcount}> : (tensor<3xf64>, tensor<3xf64>) -> index
deriv_vmap.circuit:30:13: note: operand defined here (op in a child region)
      %11 = func.call @_take(%arg0, %10) : (tensor<5x3xf64>, tensor<i64>) -> tensor<3xf64>
            ^
While processing 'GradientLoweringPass' pass
Failed to lower MLIR module

Traceback (most recent call last):
  File "/home//.local/lib/python3.11/site-packages/catalyst/compiler.py", line 579, in run_from_ir
    compiler_output = run_compiler_driver(
                      ^^^^^^^^^^^^^^^^^^^^
RuntimeError: Compilation failed:
deriv_vmap.circuit:12:3: error: operand #0 does not dominate this use
  func.func private @vmap.circuit(%arg0: tensor<5x3xf64>, %arg1: tensor<3xf64>) -> tensor<5xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_vmap.circuit:12:3: note: see current operation: %8 = "func.call"(%13, %arg1) <{callee = @circuit_0.pcount}> : (tensor<3xf64>, tensor<3xf64>) -> index
deriv_vmap.circuit:30:13: note: operand defined here (op in a child region)
      %11 = func.call @_take(%arg0, %10) : (tensor<5x3xf64>, tensor<i64>) -> tensor<3xf64>
            ^
While processing 'GradientLoweringPass' pass
Failed to lower MLIR module

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "catalyst_query.py", line 32, in <module>
    grad = grad_fn(inputs, weights, circuit)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/api_extensions/differentiation.py", line 705, in __call__
    results = jax.grad(self.fn, argnums=argnums)(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "catalyst_query.py", line 23, in loss_fn
    return jnp.sum(circuit(x, weights))
                   ^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 533, in __call__
    return self.jaxed_function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 910, in __call__
    return self.jaxed_function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 885, in compute_jvp
    derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 864, in get_derivative_qjit
    self.derivative_functions[argnum_key] = QJIT(
                                            ^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 513, in __init__
    self.aot_compile()
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 561, in aot_compile
    self.compiled_function, self.qir = self.compile()
                                       ^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/jit.py", line 755, in compile
    shared_object, llvm_ir = self.compiler.run(self.mlir_module, self.workspace)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/compiler.py", line 624, in run
    return self.run_from_ir(
           ^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home//.local/lib/python3.11/site-packages/catalyst/compiler.py", line 591, in run_from_ir
    raise CompileError(*e.args) from e
catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_vmap.circuit:12:3: error: operand #0 does not dominate this use
  func.func private @vmap.circuit(%arg0: tensor<5x3xf64>, %arg1: tensor<3xf64>) -> tensor<5xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_vmap.circuit:12:3: note: see current operation: %8 = "func.call"(%13, %arg1) <{callee = @circuit_0.pcount}> : (tensor<3xf64>, tensor<3xf64>) -> index
deriv_vmap.circuit:30:13: note: operand defined here (op in a child region)
      %11 = func.call @_take(%arg0, %10) : (tensor<5x3xf64>, tensor<i64>) -> tensor<3xf64>
            ^
While processing 'GradientLoweringPass' pass
Failed to lower MLIR module


For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

qml.about()

Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning, PennyLane_Lightning_GPU

Platform info: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.9
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:

  • lightning.qubit (PennyLane_Lightning-0.39.0)
  • nvidia.custatevec (PennyLane-Catalyst-0.9.0)
  • nvidia.cutensornet (PennyLane-Catalyst-0.9.0)
  • oqc.cloud (PennyLane-Catalyst-0.9.0)
  • softwareq.qpp (PennyLane-Catalyst-0.9.0)
  • default.clifford (PennyLane-0.39.0)
  • default.gaussian (PennyLane-0.39.0)
  • default.mixed (PennyLane-0.39.0)
  • default.qubit (PennyLane-0.39.0)
  • default.qutrit (PennyLane-0.39.0)
  • default.qutrit.mixed (PennyLane-0.39.0)
  • default.tensor (PennyLane-0.39.0)
  • null.qubit (PennyLane-0.39.0)
  • reference.qubit (PennyLane-0.39.0)
  • lightning.gpu (PennyLane_Lightning_GPU-0.39.0)

Another, much faster workaround for now is to perform both vmap and value_and_grad inside qjit (which works), and then manually use something like jax.custom_vjp to tie the result into the outer framework's gradient calculation.

@qjit
@catalyst.vmap(in_axes=(0, None), out_axes=(0, (0, 0)))
@catalyst.value_and_grad(argnums=(0, 1))
@qnode(device, interface="jax")
def circuit(inputs, weights):
    ...
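A minimal sketch of how the custom_vjp step could be wired up, assuming the compiled function returns batched values plus per-sample gradients with respect to inputs and weights, as out_axes=(0, (0, 0)) above suggests. To keep it runnable without Catalyst, batched_value_and_grad is a hypothetical pure-JAX stand-in for the qjit-compiled function; whether the real compiled function can be called inside the forward rule like this has not been verified here. The forward rule stores the per-sample gradients as residuals, and the backward rule scales them by the incoming cotangent, summing the weight contributions over the batch because the weights are shared (in_axes=None).

```python
import jax
import jax.numpy as jnp


def batched_value_and_grad(inputs, weights):
    """Hypothetical stand-in for the qjit-compiled vmap(value_and_grad(circuit)).

    Returns values of shape (batch,) and per-sample gradients with respect to
    inputs (batch, n) and weights (batch, n), mirroring out_axes=(0, (0, 0)).
    """
    def single(x, w):
        return jnp.sum(jnp.sin(x * w))

    vg = jax.vmap(jax.value_and_grad(single, argnums=(0, 1)), in_axes=(0, None))
    return vg(inputs, weights)


@jax.custom_vjp
def batched_circuit(inputs, weights):
    values, _ = batched_value_and_grad(inputs, weights)
    return values


def batched_circuit_fwd(inputs, weights):
    values, grads = batched_value_and_grad(inputs, weights)
    return values, grads  # keep the per-sample gradients as residuals


def batched_circuit_bwd(grads, cotangent):
    grad_inputs, grad_weights = grads
    # cotangent has shape (batch,): one incoming sensitivity per sample.
    d_inputs = cotangent[:, None] * grad_inputs
    # Weights are shared across the batch, so their contributions are summed.
    d_weights = jnp.sum(cotangent[:, None] * grad_weights, axis=0)
    return d_inputs, d_weights


batched_circuit.defvjp(batched_circuit_fwd, batched_circuit_bwd)


def loss_fn(inputs, weights):
    return jnp.sum(batched_circuit(inputs, weights) ** 2)
```

With this, an outer jax.grad(loss_fn) reuses the gradients computed in the forward call instead of differentiating through the vmap'd QNode.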

Thanks for sharing your workarounds, @LasradoRohan, that's really helpful! There is indeed still an issue with differentiating a vmap'd function. We believe the problem may lie in the differentiation engine we use (Enzyme), and we are working with its developers on a fix.