If you really need true batching/broadcasting, I suggest using `default.qubit`.
Alternatively, you could try parallelizing over inputs with `jax.pmap`:
import os

# Number of virtual host (CPU) devices that jax.pmap will map over.
# The XLA flag must be set *before* jax is imported, otherwise it is ignored.
N_DEVICES = 6
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={N_DEVICES}"

import jax
import jax.numpy as jnp
import pennylane as qml

# `from jax.config import config` was removed in newer JAX releases;
# `jax.config` is the supported way to reach the configuration object.
jax.config.update("jax_enable_x64", True)

N_WIRES = 2
dev = qml.device("lightning.qubit", wires=N_WIRES)


@jax.jit
@qml.qnode(dev, interface="jax", diff_method="parameter-shift")
def circuit(inputs):
    """Apply RX(inputs) on wire 0 and measure <Z> on every wire.

    Args:
        inputs: rotation angle for the RX gate (a scalar per pmap shard).

    Returns:
        The PauliZ expectation values, one per wire.
    """
    qml.RX(inputs, 0)
    return [qml.expval(qml.PauliZ(q)) for q in range(N_WIRES)]


inputs = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

# pmap needs one device per element along the mapped axis; fail with a clear
# message instead of an opaque XLA error if the flag above did not take effect
# (e.g. jax was already imported earlier in this process).
if inputs.shape[0] > jax.local_device_count():
    raise RuntimeError(
        f"jax.pmap needs {inputs.shape[0]} devices but only "
        f"{jax.local_device_count()} are available; set XLA_FLAGS before "
        "importing jax."
    )

# in_axes=0: map over the leading axis of `inputs` (one angle per device).
batched_qnode = jax.pmap(circuit, in_axes=0)
print(batched_qnode(inputs))
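That said, the practical benefit of this approach is debatable. The "devices" here are only virtual CPU devices created by the XLA flag, so any speedup is limited to what your host CPU can run in parallel.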