JAX support for the Strawberry Fields plugin

Interesting… Alternatively, you can use qml.batch_input. Just be aware that it might be deprecated soon. However, since the PennyLane-SF plugin (PL-SF) is also being deprecated, using batch_input should be fine with v0.29.

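```python
from jax import numpy as jnp
import pennylane as qml

# Strawberry Fields Fock backend (requires the pennylane-sf plugin)
dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=5)

# batch_input treats the gate parameter at index 0 (phi) as a batch:
# the circuit is evaluated once per entry of phi and the results are stacked.
@qml.batch_input(argnum=0)
@qml.qnode(dev, interface="jax")
def circuit(phi, theta):
    # Displace mode 0 by magnitude phi with phase angle 3.14
    qml.Displacement(phi, 3.14, wires=0)
    # Beamsplitter on modes 0 and 1: theta[0] is the transmittivity angle,
    # theta[1] the phase angle
    qml.Beamsplitter(*theta, wires=[0, 1])
    # Mean photon number in mode 0
    return qml.expval(qml.NumberOperator(0))

phis = jnp.array([0.1, 0.2, 0.3, 0.4])  # batch of 4 displacement magnitudes
thetas = jnp.array([0.5, 0.6])          # beamsplitter angles shared by the batch
circuit(phis, thetas)                   # one expectation value per entry of phis
```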
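To make the batching explicit, here is a minimal pure-JAX sketch of the pattern batch_input applies in the circuit above: evaluate the function once per entry of the batched argument, reuse the other arguments, and stack the results. The names mean_photons and batch_over_first_arg are hypothetical. mean_photons is an analytic stand-in for the SF QNode, not the plugin itself: for a displaced vacuum in mode 0 and vacuum in mode 1, a beamsplitter with transmission amplitude cos(theta) leaves a mean photon number of cos²(theta)·|alpha|² in mode 0, ignoring Fock-space truncation. This lets the sketch run without the plugin installed.

```python
import jax
import jax.numpy as jnp


def mean_photons(alpha, theta):
    # Hypothetical analytic stand-in for the SF QNode (not the plugin).
    # Vacuum displaced by real amplitude alpha in mode 0, vacuum in mode 1,
    # then a beamsplitter with transmission amplitude cos(theta[0]):
    # <n_0> = cos(theta[0])**2 * alpha**2 (Fock truncation ignored).
    # The beamsplitter phase theta[1] does not change <n_0>.
    return jnp.cos(theta[0]) ** 2 * alpha ** 2


def batch_over_first_arg(fn):
    # Hypothetical helper sketching the batch_input idea: call fn once per
    # entry along axis 0 of the first argument, reuse the other arguments,
    # and stack the per-entry results into one array.
    def wrapped(batched, *rest):
        return jnp.stack([fn(x, *rest) for x in batched])

    return wrapped


phis = jnp.array([0.1, 0.2, 0.3, 0.4])  # batched displacement magnitudes
thetas = jnp.array([0.5, 0.6])          # shared beamsplitter angles

looped = batch_over_first_arg(mean_photons)(phis, thetas)
# Same result via jax.vmap, which works here because the stand-in is pure JAX
vmapped = jax.vmap(mean_photons, in_axes=(0, None))(phis, thetas)
print(looped.shape)  # (4,)
```

The vmap line only works because the stand-in is plain JAX; the sketch makes no assumption about whether vmap can trace through the SF device itself, which is where a loop-style transform such as batch_input is useful.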