I tried to use lightning.qubit with JAX but got the following error:
INTERNAL: Generated function failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:
1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None
Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7f3530049770>, [0], False, [array([3., 4.])]
Here is the code to reproduce it:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide GPUs so JAX runs on CPU

import numpy as np
import jax
import pennylane as qml
from jax.config import config as jax_config

jax_config.update("jax_enable_x64", True)
print(jax.devices())

dev = qml.device('lightning.qubit', wires=4)

def circuit(inputs):
    for q in range(4):
        qml.RX(inputs[q], wires=q)
    return [qml.expval(qml.PauliZ(q)) for q in range(4)]

# batch of 2 samples, 4 rotation angles each
inputs = np.floor(np.random.uniform(size=(2, 4)) * 10)
print(inputs)

qnode = jax.jit(qml.QNode(circuit, device=dev, diff_method='best'))
batched_qnode = jax.vmap(qnode, in_axes=(0))
batched_qnode(inputs)  # raises the CpuCallback TypeError shown above
The code works fine with default.qubit. Any help is appreciated.
These are the library versions I am using:
jax 0.4.8
jaxlib 0.4.7+cuda11.cudnn86
PennyLane 0.29.1
PennyLane-Lightning 0.29.0
Hey @Gopal_Dahale! Welcome back!
I can replicate your error and am looking into this. Will get back to you ASAP!
Hey @Gopal_Dahale! The error message here isn’t very helpful. Even though jax.vmap seems like it should handle batching on its own (i.e., without PennyLane), it turns out that this isn’t the case: Lightning devices do not support batching at all. For now you will have to run your inputs serially (one at a time), or use default.qubit if batching is necessary.
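A minimal sketch of the serial fallback mentioned above. The helper name run_serially and the stand-in function are hypothetical and not part of PennyLane. The stand-in replaces the real lightning.qubit QNode so the snippet runs without PennyLane installed; it uses the fact that RX(theta) applied to |0> gives an expectation value of cos(theta) for PauliZ.

```python
import math

def run_serially(qnode, batch):
    """Call `qnode` once per row of `batch` and collect the results in order.

    Hypothetical helper: `qnode` can be any callable that accepts one
    (unbatched) sample, e.g. a QNode bound to lightning.qubit.
    """
    return [qnode(row) for row in batch]

def stand_in_qnode(inputs):
    """Stand-in for the 4-qubit circuit in this thread (an assumption).

    Each wire gets RX(inputs[q]) on |0>, so <Z_q> = cos(inputs[q]).
    """
    return [math.cos(x) for x in inputs]

batch = [[3.0, 4.0, 0.0, 1.0],
         [2.0, 5.0, 7.0, 9.0]]

results = run_serially(stand_in_qnode, batch)
print(len(results), len(results[0]))  # one result row per sample
```

With a real device you would pass the QNode itself in place of stand_in_qnode. The loop costs one circuit execution per sample, which is also what happens under the hood when the batch is passed directly.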
Hope this helps!
Instead of using jax.vmap, if I simply pass inputs to qnode, it runs without any error. Does batching work in this case? Also, can I use the batch_input transform?
I needed clarification from the performance team on this myself. The behaviour you’re seeing when jax.vmap isn’t used makes it look like parameter broadcasting is working. However, what’s actually going on under the hood is that the circuit is executed serially, once per input, and the results are displayed together. This happens with or without jax.jit.
For example, with this circuit
dev = qml.device('lightning.qubit', wires=1)

@qml.qnode(dev)
def circuit(inputs):
    qml.RX(inputs, wires=0)
    return qml.expval(qml.PauliZ(0))
if I do
inputs = np.array([0.1, 0.2, 0.3])
circuit(inputs)
or
inputs = np.array([0.1, 0.2, 0.3])
for val in inputs:
    circuit(val)
they’re effectively doing the same thing. So, this isn’t true broadcasting (true parallel execution).
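To make the distinction concrete, here is a rough NumPy-only sketch (not PennyLane code, and not how Lightning is implemented) of the same one-qubit RX circuit. The function names are made up for illustration. The first function handles one angle per call, as in the loop above; the second applies a whole stack of RX matrices in a single vectorized operation, which is what true broadcasting would look like.

```python
import numpy as np

ZERO = np.array([1.0, 0.0], dtype=complex)  # the |0> state

def expval_z_single(theta):
    """<Z> after RX(theta)|0>, for one scalar angle."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    rx = np.array([[c, -1j * s], [-1j * s, c]])
    probs = np.abs(rx @ ZERO) ** 2
    return probs[0] - probs[1]

def expval_z_broadcast(thetas):
    """<Z> for a whole batch of angles in one vectorized pass."""
    thetas = np.asarray(thetas, dtype=float)
    c, s = np.cos(thetas / 2), np.sin(thetas / 2)
    row0 = np.stack([c, -1j * s], axis=-1)      # shape (B, 2)
    row1 = np.stack([-1j * s, c], axis=-1)      # shape (B, 2)
    rx_batch = np.stack([row0, row1], axis=-2)  # shape (B, 2, 2)
    probs = np.abs(rx_batch @ ZERO) ** 2        # shape (B, 2)
    return probs[:, 0] - probs[:, 1]

inputs = np.array([0.1, 0.2, 0.3])
serial = np.array([expval_z_single(t) for t in inputs])  # one call per angle
batched = expval_z_broadcast(inputs)                     # one call in total
print(np.allclose(serial, batched))
```

Both paths return the same numbers (cos of each angle); only the execution pattern differs.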
Hope this clears things up!
If you really need true batching/broadcasting, I suggest using default.qubit, or you can maybe try using jax.pmap:
import os
# expose 6 virtual CPU devices so pmap has one device per input
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=6'

import jax
jax.devices("cpu")

import pennylane as qml
from pennylane import numpy as np
from jax.config import config as jax_config

jax_config.update("jax_enable_x64", True)

dev = qml.device('lightning.qubit', wires=2)

@jax.jit
@qml.qnode(dev, diff_method="parameter-shift")
def circuit(inputs):
    qml.RX(inputs, 0)
    return [qml.expval(qml.PauliZ(q)) for q in range(2)]

inputs = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
batched_qnode = jax.pmap(circuit, in_axes=(0))
print(batched_qnode(inputs))
Although the overall utility of this approach is debatable.
Thanks for the response @isaacdevlugt. I can certainly use default.qubit, but its performance will not scale with the number of qubits. Also, my main aim is to execute the circuit on GPUs, so jax.pmap seems like a nice alternative, but I will have to look into sharding to optimize this further.
Thank you very much @isaacdevlugt.
Happy to help!
It actually turns out that when an attempt was made to implement parameter broadcasting with lightning.qubit, it slowed execution down in the regimes where it was expected to be advantageous (~20 qubits), so it was scrapped. Hence the suggestion to use default.qubit if you absolutely need broadcasting.
Best of luck!