Catalyst compilation failure: operand #1 does not dominate this use

Hi! I’m trying out the PennyLane demo "How to optimize a QML model using JAX and Optax".

This code snippet is the same as the tutorial's, except that I use lightning.qubit with Catalyst instead of default.qubit. I'm using PennyLane 0.35.0, Catalyst 0.5.0, and an NVIDIA GPU. Finally, jit_circuit uses catalyst.vmap to enable broadcasting on lightning.qubit (otherwise a list comprehension over the data samples is required):

import pennylane as qml
import jax
from jax import numpy as jnp
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev_name = "lightning.qubit"
dev = qml.device(dev_name, wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    for i in range(n_wires):
        qml.RY(data[i], wires=i)

    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# try broadcasting
jit_circuit = catalyst.qjit(catalyst.vmap(circuit, in_axes = (1, None)))

def my_model(data, weights, bias):
    # works with default.qubit
    if dev_name == "default.qubit":
        return circuit(data, weights) + bias

    # works with lightning.qubit, not broadcasted
    # return jnp.array([circuit(jnp.array(d), weights) for d in data.T]) + bias

    # works for the forward pass of loss_fn, but fails at the grad step
    return jit_circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss


weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))
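For reference, `in_axes=(1, None)` maps over axis 1 of `data` (the four columns of the (5, 4) array, one sample per column) while passing `weights` unchanged to every call, so the batched result should match the commented-out list comprehension. The sketch below checks that equivalence with `jax.vmap` and a hypothetical classical stand-in, `fake_circuit`, in place of the QNode, so it runs without PennyLane or a GPU. It assumes `catalyst.vmap` follows the same `in_axes` convention as `jax.vmap`.

```python
import jax
from jax import numpy as jnp

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3  # shape (5, 4)
weights = jnp.ones([n_wires, 3])

def fake_circuit(x, weights):
    """Hypothetical classical stand-in for the QNode: (5,) sample, (5, 3) weights -> scalar."""
    return jnp.sum(jnp.cos(x + weights[:, 0]) * jnp.sin(weights[:, 1]))

# Map over axis 1 of data (one column per sample); reuse the same weights for every call.
batched = jax.vmap(fake_circuit, in_axes=(1, None))(data, weights)

# Unbatched equivalent, mirroring the commented-out list comprehension over data.T.
looped = jnp.array([fake_circuit(d, weights) for d in data.T])

print(data.shape, batched.shape)  # (5, 4) (4,)
```

Both versions produce one prediction per column of `data`, which is why `targets` has four entries.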

Computing loss_fn works, but computing jax.grad(loss_fn) fails to compile. The primary error is “operand #1 does not dominate this use”, raised while the compiler runs the 'GradientLoweringPass' and fails to lower the MLIR module for the gradient. I’m not sure how to resolve this error. Could someone please help debug it? Thanks!
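One way to narrow this down is to run the same `jax.grad(loss_fn)` call with the quantum circuit swapped for a differentiable classical stand-in. If that gradient compiles and returns a dict shaped like `params`, the classical parts (the `jax.jit` loss, the dict of parameters, the `in_axes=(1, None)` batching) are probably not the problem, which points at the derivative Catalyst compiles for the batched QNode (the traceback below shows the failure inside `get_derivative_qjit`, during the `GradientLoweringPass`). This is a minimal sketch under assumptions: `fake_circuit` is hypothetical and not part of the tutorial, and it uses `jax.vmap` instead of `catalyst.vmap`.

```python
import jax
from jax import numpy as jnp

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

def fake_circuit(x, weights):
    """Hypothetical differentiable stand-in for the QNode (not a quantum circuit)."""
    return jnp.sum(jnp.cos(x + weights[:, 0]) * jnp.sin(weights[:, 1] + weights[:, 2]))

# Same batching pattern as the post, but with jax.vmap instead of catalyst.vmap.
batched_model = jax.vmap(fake_circuit, in_axes=(1, None))

def my_model(data, weights, bias):
    return batched_model(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    # Kept identical to the tutorial's loss, including the division by len(data).
    return jnp.sum((targets - predictions) ** 2 / len(data))

params = {"weights": jnp.ones([n_wires, 3]), "bias": jnp.array(0.0)}
grads = jax.grad(loss_fn)(params, data, targets)

print(loss_fn(params, data, targets))
print(grads["weights"].shape, grads["bias"].shape)  # (5, 3) ()
```

If this stand-in differentiates cleanly, swapping `fake_circuit` back for the Catalyst-compiled QNode is the only change between a working and a failing gradient.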

2024-03-06 11:57:10.540238: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
0.2923261237897222
Traceback (most recent call last):
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 459, in run_from_ir
    compiler_output = run_compiler_driver(
                      ^^^^^^^^^^^^^^^^^^^^
RuntimeError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
  func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
      %21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 54, in <module>
    print(loss_fn(params, data, targets))
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 45, in loss_fn
    predictions = my_model(data, params["weights"], params["bias"])
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 41, in my_model
    return jit_circuit(data, weights) + bias
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 114, in __call__
    return self.jaxed_function(*args, **kwargs)
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 419, in __call__
    return self.jaxed_function(*args, **kwargs)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
  func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
      %21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 56, in <module>
    print(jax.grad(loss_fn)(params, data, targets))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 395, in compute_jvp
    derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 375, in get_derivative_qjit
    self.derivative_functions[argnum_key] = QJIT(
                                            ^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 101, in __init__
    self.aot_compile()
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 135, in aot_compile
    self.compiled_function, self.qir = self.compile()
                                       ^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 271, in compile
    shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 504, in run
    return self.run_from_ir(
           ^^^^^^^^^^^^^^^^^
  File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 469, in run_from_ir
    raise CompileError(*e.args) from e
catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
  func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  ^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
      %21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
            ^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Hey @schance995!

I can’t replicate the behaviour you’re getting. In any case, try using @qml.qjit in place of @jax.jit! It’s worth looking through the Catalyst Quick Start guide (Catalyst 0.5.0 documentation). Let me know if that helps!