Interesting… Alternatively, you can use `qml.batch_input`; just be aware that it might be deprecated soon (although, since PL-SF, the PennyLane–Strawberry Fields plugin, is also being deprecated, using `batch_input` should be fine with v0.29).
from jax import numpy as jnp
import pennylane as qml
# Strawberry Fields Fock-basis device with two wires and cutoff_dim=5.
dev = qml.device(
    "strawberryfields.fock",
    wires=2,
    cutoff_dim=5,
)
@qml.batch_input(argnum=0)
@qml.qnode(dev, interface="jax")
def circuit(phi, theta):
    """Displace wire 0, mix wires 0 and 1, and measure the photon number on wire 0.

    Args:
        phi: First positional argument of ``qml.Displacement`` on wire 0.
            This is argument 0, the one selected by ``batch_input(argnum=0)``.
        theta: Iterable unpacked positionally into ``qml.Beamsplitter``
            acting on wires 0 and 1.

    Returns:
        The expectation value of ``qml.NumberOperator`` on wire 0.
    """
    # The second Displacement argument is fixed at 3.14 for every call.
    qml.Displacement(phi, 3.14, wires=0)
    qml.Beamsplitter(*theta, wires=[0, 1])
    return qml.expval(qml.NumberOperator(0))
# Four values for the circuit's first argument (the one batch_input targets).
phis = jnp.array((0.1, 0.2, 0.3, 0.4))
# Two values that the circuit unpacks into the beamsplitter.
thetas = jnp.array((0.5, 0.6))
circuit(phis, thetas)