Hi! I'm trying out the QML optimization with JAX tutorial: How to optimize a QML model using JAX and Optax | PennyLane Demos
This code snippet is the same as in the tutorial, except that I use lightning.qubit and Catalyst instead of default.qubit. I'm running PennyLane 0.35.0 and Catalyst 0.5.0 on an NVIDIA GPU. Finally, jit_circuit uses catalyst.vmap to enable broadcasting with lightning.qubit (otherwise a list comprehension is required):
import pennylane as qml
import jax
from jax import numpy as jnp
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev_name = "lightning.qubit"
dev = qml.device(dev_name, wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""
    for i in range(n_wires):
        qml.RY(data[i], wires=i)

    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# try broadcasting
jit_circuit = catalyst.qjit(catalyst.vmap(circuit, in_axes=(1, None)))

def my_model(data, weights, bias):
    # works with default.qubit
    if dev_name == "default.qubit":
        return circuit(data, weights) + bias
    # works with lightning.qubit, not broadcasted
    # return jnp.array([circuit(jnp.array(d), weights) for d in data.T])
    # only works with loss_fn, fails at grad step
    return jit_circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))
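For reference, the shapes involved: vmapping over axis 1 of data should yield one expectation value per data column, so jit_circuit(data, weights) is expected to return four predictions matching targets:

print(data.shape)     # (5, 4): one row per wire, one column per data point
print(weights.shape)  # (5, 3)
print(targets.shape)  # (4,)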
Computing loss_fn works, but computing jax.grad fails to compile. The primary error is "operand #1 does not dominate this use", and the compiler fails to lower the MLIR module while running the gradient lowering pass. I'm not sure how to resolve this error. Could someone please help debug it? Thanks! The full output is below, followed by a sketch of an alternative I was considering.
2024-03-06 11:57:10.540238: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
0.2923261237897222
Traceback (most recent call last):
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 459, in run_from_ir
compiler_output = run_compiler_driver(
^^^^^^^^^^^^^^^^^^^^
RuntimeError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
%21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 54, in <module>
print(loss_fn(params, data, targets))
File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 45, in loss_fn
predictions = my_model(data, params["weights"], params["bias"])
File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 41, in my_model
return jit_circuit(data, weights) + bias
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 114, in __call__
return self.jaxed_function(*args, **kwargs)
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 419, in __call__
return self.jaxed_function(*args, **kwargs)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
%21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/schan/quantum/src/jax/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax.py", line 56, in <module>
print(jax.grad(loss_fn)(params, data, targets))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 395, in compute_jvp
derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 375, in get_derivative_qjit
self.derivative_functions[argnum_key] = QJIT(
^^^^^
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 101, in __init__
self.aot_compile()
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 135, in aot_compile
self.compiled_function, self.qir = self.compile()
^^^^^^^^^^^^^^
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/jit.py", line 271, in compile
shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 504, in run
return self.run_from_ir(
^^^^^^^^^^^^^^^^^
File "/home/schan/miniforge3/envs/quantum-cuda12/lib/python3.11/site-packages/catalyst/compiler.py", line 469, in run_from_ir
raise CompileError(*e.args) from e
catalyst.utils.exceptions.CompileError: Compilation failed:
deriv_batched_fn:8:3: error: operand #1 does not dominate this use
func.func private @batched_fn(%arg0: tensor<5xf64>, %arg1: tensor<5xf64>, %arg2: tensor<5x4xf64>, %arg3: tensor<5x3xf64>) -> tensor<4xf64> attributes {llvm.linkage = #llvm.linkage<internal>} {
^
deriv_batched_fn:8:3: note: see current operation: %7 = "func.call"(%arg1, %21, %arg3) <{callee = @circuit.pcount}> : (tensor<5xf64>, tensor<5xf64>, tensor<5x3xf64>) -> index
deriv_batched_fn:33:13: note: operand defined here (op in a child region)
%21 = func.call @_take(%arg2, %20) : (tensor<5x4xf64>, tensor<i64>) -> tensor<5xf64>
^
While processing 'GradientLoweringPass' pass of the 'QuantumCompilationPass' pipeline
Failed to lower MLIR module
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
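In case it's useful context, the alternative I was considering (an untested sketch, assuming catalyst.grad and catalyst.vmap can be composed under a single qjit like this) is to keep the differentiation inside Catalyst via catalyst.grad instead of calling jax.grad on the outside:

# Untested sketch: differentiate inside Catalyst rather than via jax.grad.
# Assumes catalyst.vmap and catalyst.grad compose under one qjit.
@catalyst.qjit
def grad_wrt_weights(weights, bias, data, targets):
    def loss(w):
        batched = catalyst.vmap(circuit, in_axes=(1, None))
        predictions = batched(data, w) + bias
        return jnp.sum((targets - predictions) ** 2 / len(data))

    # Differentiates with respect to the first argument (the weights);
    # the bias term would still need its own gradient.
    return catalyst.grad(loss)(weights)

print(grad_wrt_weights(weights, bias, data, targets))

Would something along these lines be the recommended way to get gradients together with catalyst.vmap, or is there a fix for the jax.grad path?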