Incompatible function arguments error on lightning.qubit with JAX

I tried to use lightning.qubit with JAX but got the following error:

INTERNAL: Generated function failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:
    1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None

Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7f3530049770>, [0], False, [array([3., 4.])]

Here is the code to reproduce it:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import numpy as np
import jax
import pennylane as qml
from jax.config import config as jax_config
jax_config.update("jax_enable_x64", True)

print(jax.devices())

dev = qml.device('lightning.qubit', wires=4)

def circuit(inputs):
    
    for q in range(4):
        qml.RX(inputs[q], wires=q)

    return [qml.expval(qml.PauliZ(q)) for q in range(4)]

inputs = np.floor(np.random.uniform(size=(2,4))*10)
print(inputs)

qnode = jax.jit(qml.QNode(circuit, device=dev, diff_method='best'))
batched_qnode = jax.vmap(qnode, in_axes=(0))
batched_qnode(inputs)

The code works fine with default.qubit. Any help is appreciated.
These are the library versions I am using:

jax                          0.4.8
jaxlib                       0.4.7+cuda11.cudnn86
PennyLane                    0.29.1
PennyLane-Lightning          0.29.0

Hey @Gopal_Dahale! Welcome back!

I can replicate your error and am looking into this. Will get back to you ASAP!

Hey @Gopal_Dahale! The error message here isn’t very helpful. Even though jax.vmap seems like it should handle batching on its own (i.e., without PennyLane), it turns out that this isn’t the case: Lightning devices do not support batching at all. For now, you will have to serialize your code, or use default.qubit if batching is necessary.

Hope this helps!

If I simply pass inputs to qnode instead of using jax.vmap, it runs without any error. Does batching work in this case? Also, can I use the batch_input transform?

I needed clarification from the performance team on this myself. The behaviour you’re seeing when jax.vmap isn’t used makes it look like parameter broadcasting is working. However, what’s actually going on under the hood is that the circuit is executed serially, once per input value, and the results are returned together :sweat_smile:. This happens with or without jax.jit.

For example, with this circuit

dev = qml.device('lightning.qubit', wires=1)

@qml.qnode(dev)
def circuit(inputs):   
    qml.RX(inputs, wires=0)
    return qml.expval(qml.PauliZ(0))

if I do

inputs = np.array([0.1, 0.2, 0.3])
circuit(inputs)

or

inputs = np.array([0.1, 0.2, 0.3])

for val in inputs:
    circuit(val)

they’re effectively doing the same thing. So, this isn’t true broadcasting (true parallel execution).

Hope this clears things up!

If you really need true batching/broadcasting, I suggest using default.qubit. Alternatively, you could try jax.pmap:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=6'

import jax
jax.devices("cpu")
import pennylane as qml
from pennylane import numpy as np

from jax.config import config as jax_config

jax_config.update("jax_enable_x64", True)

dev = qml.device('lightning.qubit', wires=2)

@jax.jit
@qml.qnode(dev, diff_method="parameter-shift")
def circuit(inputs):
    qml.RX(inputs, 0)
    return [qml.expval(qml.PauliZ(q)) for q in range(2)]

inputs = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

batched_qnode = jax.pmap(circuit, in_axes=(0))
print(batched_qnode(inputs))

That said, the overall utility of this approach is debatable.

Thanks for the response @isaacdevlugt. I can certainly use default.qubit, but its performance will not scale with the number of qubits. Also, my main aim is to execute the circuit on GPUs, so jax.pmap seems like a nice alternative, but I will have to look into sharding to optimize this further.

Thank you very much @isaacdevlugt.

Happy to help!

It actually turns out that when an attempt was made to implement parameter broadcasting in lightning.qubit, it slowed execution down in the regime where it was expected to be advantageous (~20 qubits). It was scrapped as a result :slight_smile:. Hence the suggestion to use default.qubit if you absolutely need broadcasting.

Best of luck!