JAX support for the Strawberry Fields plugin

I am using the "strawberryfields.fock" device and tried to use the JAX interface in a QNode. However, I get the error: unhashable type: 'numpy.ndarray'. It is raised by both the Beamsplitter and Displacement gates.
Is there support for JAX in the Strawberry Fields plugin or in the original library?

Hey @DEVANSHU_GARG! Welcome back :rocket:

It sounds like the problem is something else. This code doesn't reproduce the error you're getting:

import pennylane as qml

dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)

@qml.qnode(dev, interface="jax")
def circuit():
  qml.Beamsplitter(0.1, 0.2, wires=[0, 1])
  return qml.expval(qml.NumberOperator(0))

print(circuit()) # Array(0., dtype=float32)

Make sure that you've updated to the latest versions of PennyLane (v0.30 at the time of writing) and the PL-SF plugin. If that doesn't solve it, I'll need to see your full code :slight_smile:


Just as a warning, though: the PL-SF plugin will not be supported in newer versions of PennyLane. It is compatible with PennyLane versions up to and including 0.29. Just be aware of that :slight_smile:
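If you're not sure which PennyLane version is installed, a small check like the sketch below can flag an incompatible one. It assumes only what's stated above (PL-SF works up to and including 0.29); the helper name plsf_compatible is hypothetical and not part of PennyLane or the plugin.

```python
# Hypothetical helper: returns True if a PennyLane version string is within
# the range the PL-SF plugin supports (up to and including 0.29, per this thread).
def plsf_compatible(version: str) -> bool:
    # Keep only the "major.minor" part, e.g. "0.29.1" -> (0, 29)
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) <= (0, 29)


print(plsf_compatible("0.29.1"))  # True
print(plsf_compatible("0.30.0"))  # False
```

In practice you would pass the installed version, e.g. plsf_compatible(qml.__version__) after import pennylane as qml.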

Alternatively, you can use Strawberry Fields directly, but you won't have JAX support.

Thank you @isaacdevlugt :smile:.
I want to use a JAX array for the circuit parameters, and the error is raised when I pass one:

import pennylane as qml
import jax.random as random

dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
key = random.PRNGKey(0)
x = random.uniform(key, shape=(1, 2))

@qml.qnode(dev, interface="jax")
def circuit(init):
  qml.Beamsplitter(init[0], init[1], wires=[0, 1])
  return qml.expval(qml.NumberOperator(0))
print(circuit(x))

Thank you for this information @isaacdevlugt. Maybe I will test this out and migrate to Strawberry Fields.

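Oh! It looks like you're not accessing the elements of x the way you intend. With shape=(1, 2), x is a 2D array, so init[0] is a whole row of two numbers rather than a single scalar parameter. If you just want two random numbers, use x = random.uniform(key, shape=(2,)):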

import pennylane as qml
import jax.random as random

dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
key = random.PRNGKey(0)
x = random.uniform(key, shape=(2,))

@qml.qnode(dev, interface="jax")
def circuit(init):
  qml.Beamsplitter(init[0], init[1], wires=[0, 1])
  return qml.expval(qml.NumberOperator(0))

print(circuit(x))

That should work!
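To see the shape difference concretely, here is a small sketch that uses NumPy arrays as stand-ins for the JAX arrays (the indexing shapes are the same). It only inspects shapes and does not call PennyLane; the variable names are illustrative.

```python
import numpy as np

# Stand-ins for random.uniform(key, shape=(1, 2)) and random.uniform(key, shape=(2,))
x_2d = np.random.default_rng(0).uniform(size=(1, 2))
x_1d = np.random.default_rng(0).uniform(size=(2,))

# With shape (1, 2), indexing the first axis returns a whole row:
# init[0] is an array of two numbers, not a scalar gate parameter.
print(x_2d[0].shape)  # (2,)

# With shape (2,), each element is a scalar, which is what
# qml.Beamsplitter expects for its two angles.
print(x_1d[0].shape, x_1d[1].shape)  # () ()

# If you already have a (1, 2) array, flattening it gives the (2,) layout:
print(x_2d.reshape(-1).shape)  # (2,)
```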

Thank you for your help @isaacdevlugt! The code snippet is working now.
I tried to use the function with vmap, but I got the same error again. It again seems to be a problem with how the dimensions are specified, but I can't figure out the right one. The second argument (theta) is kept constant across all batches, and the first one (phi) is the one to iterate over.

import numpy as np
import jax.numpy as jnp
import jax.random as random
from jax import vmap
import pennylane as qml

dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
key = random.PRNGKey(0)
key2 = random.PRNGKey(1)
x = random.uniform(key, shape=(2,))
y = random.uniform(key2, shape=(10,1))
state = jnp.array([0.2])

@qml.qnode(dev, interface="jax")
def circuit(phi, theta):
    qml.Beamsplitter(theta[0], theta[1], wires=[0, 1])
    qml.Displacement(phi, np.pi, wires=0)
    return qml.expval(qml.NumberOperator(0))

build_circuit = vmap(circuit, in_axes = (0, None), out_axes = 0)
print(build_circuit(y, x))
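Two things stand out in the vmap snippet. First, y has shape (10, 1), so with in_axes=(0, None) each phi handed to the circuit is a length-1 array rather than a scalar. Second, and this is an assumption not confirmed in this thread, the strawberryfields.fock device computes with NumPy, so the traced values vmap passes through may not be accepted by the device at all. A plain Python loop avoids tracing entirely. The sketch below uses a hypothetical helper loop_map and a stand-in fake_circuit so it runs without PennyLane.

```python
import numpy as np


def loop_map(fn, batched, fixed):
    """Hypothetical helper: call fn(b, fixed) for each entry b along axis 0
    of `batched` and stack the results, mimicking vmap(fn, in_axes=(0, None))
    without any tracing."""
    return np.stack([fn(b, fixed) for b in batched])


def fake_circuit(phi, theta):
    # Stand-in for the QNode. For Beamsplitter then Displacement(phi, pi)
    # starting from vacuum, <n_0> would be phi**2 (ignoring the Fock cutoff).
    return phi ** 2


y = np.linspace(0.1, 1.0, 10).reshape(10, 1)  # same shape as in the post: (10, 1)
theta = np.array([0.5, 0.6])

# Each row of a (10, 1) array has shape (1,): phi is a 1-element array, not a scalar.
print(y[0].shape)  # (1,)

# Flattening to shape (10,) makes each phi a scalar.
out = loop_map(fake_circuit, y.reshape(-1), theta)
print(out.shape)  # (10,)
```

With the real QNode you would pass circuit in place of fake_circuit, e.g. loop_map(circuit, y.reshape(-1), x); jnp.stack works equally well if you want a JAX array back. This runs the device once per batch entry, so it trades speed for simplicity.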

Interesting… Alternatively, you can use qml.batch_input. Just be aware that it might be deprecated soon (although, since PL-SF is also being deprecated, using batch_input should be fine with v0.29).

from jax import numpy as jnp
import pennylane as qml

dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=5)

@qml.batch_input(argnum=0)
@qml.qnode(dev, interface="jax")
def circuit(phi, theta):
    qml.Displacement(phi, 3.14, wires=0)
    qml.Beamsplitter(*theta, wires=[0, 1])
    return qml.expval(qml.NumberOperator(0))

phis = jnp.array([0.1, 0.2, 0.3, 0.4])
thetas = jnp.array([0.5, 0.6])
print(circuit(phis, thetas))
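As a sanity check on the batched results, the expected photon number can be worked out by hand. Displacement(a, phi) on vacuum gives a coherent state with mean photon number a², and PennyLane's Beamsplitter(theta, phi) transmits amplitude cos(theta), so mode 0 ends up with a²cos²(theta). The same formula with a = 0 explains the 0 printed by the first snippet in this thread. The helper expected_n0 below is hypothetical and ignores Fock truncation, so device results should agree only up to a small cutoff_dim error.

```python
import numpy as np


def expected_n0(a, theta):
    """Analytic <n_0> for Displacement(a, phase) on mode 0 followed by
    Beamsplitter(theta, phase') on modes [0, 1], both modes starting in vacuum.
    Mode 0 keeps a coherent state with amplitude a*cos(theta); the phases
    drop out of the photon number. Ignores Fock truncation (cutoff_dim)."""
    return (a * np.cos(theta)) ** 2


phis = np.array([0.1, 0.2, 0.3, 0.4])
thetas = np.array([0.5, 0.6])

# Only the transmission angle thetas[0] affects <n_0>; 3.14 and thetas[1] are phases.
print(expected_n0(phis, thetas[0]))
```

Comparing these values with the output of circuit(phis, thetas) is a quick way to confirm the batching maps each phi to its own circuit evaluation.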

Thank you @isaacdevlugt.
This was a great help.


Awesome! Glad we were able to help!