I am using the “strawberryfields.fock” device and tried to use the JAX interface in a QNode. However, I get the error: unhashable type: ‘numpy.ndarray’. The error is raised by both the Beamsplitter and Displacement gates.
Is JAX supported in the Strawberry Fields plugin or in the original library?
Hey @DEVANSHU_GARG! Welcome back
It sounds like it’s something else. This code doesn’t reproduce the error you’re getting:
import pennylane as qml
dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
@qml.qnode(dev, interface="jax")
def circuit():
    qml.Beamsplitter(0.1, 0.2, wires=[0, 1])
    return qml.expval(qml.NumberOperator(0))
print(circuit()) # Array(0., dtype=float32)
Make sure that you update to the latest version of PennyLane (v0.30 as of right now) and the PL-SF plugin. If that doesn’t solve it, I’ll need to see your full code.
Just as a warning, though: the PL-SF plugin will not be supported in newer versions of PennyLane. It is compatible with versions of PennyLane up to and including 0.29.
Alternatively, you can use Strawberry Fields directly, but you won’t have JAX support.
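If it helps, here is a quick way to check which versions are installed before digging further. This is only a sketch: installed_version is a hypothetical helper, and the names in the loop are assumed to be the PyPI distribution names.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is missing.

    Hypothetical helper for illustration; not part of PennyLane.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Assumed PyPI distribution names for the packages discussed in this thread.
for name in ("PennyLane", "PennyLane-SF", "StrawberryFields", "jax"):
    print(name, installed_version(name))
```

A None in the output means that package was not found in the current environment.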
Thank you @isaacdevlugt .
I want to use a JAX array for the circuit parameters. The error is raised when I pass a JAX array:
import pennylane as qml
import jax.random as random
dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
key = random.PRNGKey(0)
x = random.uniform(key, shape=(1, 2))
@qml.qnode(dev, interface="jax")
def circuit(init):
    qml.Beamsplitter(init[0], init[1], wires=[0, 1])
    return qml.expval(qml.NumberOperator(0))
print(circuit(x))
Thank you for this information @isaacdevlugt . I may test this out and migrate to Strawberry Fields.
Oh! Looks like you’re not properly accessing the elements of x. With shape=(1, 2), init[0] is a length-2 array rather than a single number. If you just want two random numbers, you can do x = random.uniform(key, shape=(2,)):
import pennylane as qml
import jax.random as random
dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
key = random.PRNGKey(0)
x = random.uniform(key, shape=(2,))
@qml.qnode(dev, interface="jax")
def circuit(init):
    qml.Beamsplitter(init[0], init[1], wires=[0, 1])
    return qml.expval(qml.NumberOperator(0))
print(circuit(x))
That should work!
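To see the difference between the two shapes without touching the device, you can inspect what each index returns. A minimal sketch using only JAX; the names x_2d and x_1d are just for illustration:

```python
import jax.random as random

key = random.PRNGKey(0)

x_2d = random.uniform(key, shape=(1, 2))  # shape used in the failing snippet
x_1d = random.uniform(key, shape=(2,))    # shape used in the working snippet

# Indexing the 2-D array once returns a whole row, not a scalar.
print(x_2d[0].shape)  # (2,)

# Indexing the 1-D array returns scalar (0-dimensional) arrays.
print(x_1d[0].shape, x_1d[1].shape)  # () ()

# The 2-D array needs two indices to reach a single number.
print(x_2d[0, 0].shape)  # ()
```

So with the (1, 2) array, the gate was receiving an array where it expected a single parameter.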
Thank you for your help @isaacdevlugt ! The code snippet is working now.
I tried to use the function with vmap, but I got the same error again. It seems to be a problem with how I specify the dimensions, but I can’t figure out the right ones. The second parameter is kept constant across all batches, and the first one is iterated over.
import pennylane as qml
import numpy as np
import jax.numpy as jnp
import jax.random as random
from jax import vmap
dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=10)
key = random.PRNGKey(0)
key2 = random.PRNGKey(1)
x = random.uniform(key, shape=(2,))
y = random.uniform(key2, shape=(10,1))
state = jnp.array([0.2])
@qml.qnode(dev, interface="jax")
def circuit(phi, theta):
    qml.Beamsplitter(theta[0], theta[1], wires=[0, 1])
    qml.Displacement(phi, np.pi, wires=0)
    return qml.expval(qml.NumberOperator(0))
build_circuit = vmap(circuit, in_axes = (0, None), out_axes = 0)
print(build_circuit(y, x))
Interesting… Alternatively, you can use qml.batch_input; just be aware that this might be deprecated soon (although, since PL-SF is also being deprecated, using batch_input should be fine with v0.29).
from jax import numpy as jnp
import pennylane as qml
dev = qml.device("strawberryfields.fock", wires=2, cutoff_dim=5)
@qml.batch_input(argnum=0)
@qml.qnode(dev, interface="jax")
def circuit(phi, theta):
    qml.Displacement(phi, 3.14, wires=0)
    qml.Beamsplitter(*theta, wires=[0, 1])
    return qml.expval(qml.NumberOperator(0))
phis = jnp.array([0.1, 0.2, 0.3, 0.4])
thetas = jnp.array([0.5, 0.6])
circuit(phis, thetas)
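As a side note on the vmap attempt: the in_axes=(0, None) semantics can be checked on a plain JAX function, independent of the device. A minimal sketch, where toy_circuit is a hypothetical stand-in with ordinary arithmetic, not a QNode. It only shows what shapes each mapped call receives; whether that shape is what triggers the error on the SF device is not something this sketch confirms.

```python
import jax.numpy as jnp
from jax import random, vmap


def toy_circuit(phi, theta):
    # Hypothetical stand-in for the QNode: plain arithmetic, no quantum device.
    return phi * jnp.cos(theta[0]) + theta[1]


theta = random.uniform(random.PRNGKey(0), shape=(2,))
phis_2d = random.uniform(random.PRNGKey(1), shape=(10, 1))  # shape from the vmap attempt
phis_1d = phis_2d[:, 0]                                     # same values, shape (10,)

# Map over axis 0 of the first argument; pass the second argument unchanged.
batched = vmap(toy_circuit, in_axes=(0, None), out_axes=0)

# With a (10, 1) input, each mapped call sees phi with shape (1,), not a scalar.
out_2d = batched(phis_2d, theta)
print(out_2d.shape)  # (10, 1)

# With a (10,) input, each mapped call sees a scalar phi.
out_1d = batched(phis_1d, theta)
print(out_1d.shape)  # (10,)
```

This mirrors the batch_input example, where phis is a 1-D array so each circuit evaluation gets a single number for phi.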
Thank you @isaacdevlugt .
This was a great help.
Awesome! Glad we were able to help!