Hi there,
Here is what I am trying to do. I have a circuit in which a binary input `X` prepares an initial state, followed by some arbitrary circuit from which I sample in the Z basis. In JAX I can use the `vmap` function to feed in a whole batch `X_batch` of inputs `X` and compute the samples for every `X` at the same time.
The problem is that if `X_batch` contains the same `X` several times, I get exactly the same samples for each copy.
I am not sure whether this is a bug or a feature, but I would like the sampling to stay random even when the inputs are the same. Is there any way of telling `vmap` to keep the samples random?
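To make the question concrete, here is a minimal pure-JAX sketch of the behaviour, with no PennyLane involved. It assumes (not confirmed here for `default.qubit.jax`) that the sampler draws all shots from a single PRNG key that is not batched along with the inputs. `sample_pm1` is a hypothetical stand-in for Z-basis sampling. With one shared key, `vmap` returns identical samples for identical inputs; splitting the key with `jax.random.split` and mapping over the subkeys gives each batch element its own randomness.

```python
import jax
import jax.numpy as jnp

SHOTS = 100  # samples per input, matching shots=100 in the post


def sample_pm1(key, probs):
    """Hypothetical stand-in for Z-basis sampling (not a PennyLane device).

    probs[i] is the probability that qubit i reads -1. Returns an array of
    shape (SHOTS, n_qubits) with entries +1.0 or -1.0.
    """
    bits = jax.random.bernoulli(key, probs, shape=(SHOTS, probs.shape[0]))
    return 1.0 - 2.0 * bits.astype(jnp.float32)  # bit 0 -> +1, bit 1 -> -1


probs_batch = jnp.full((3, 4), 0.5)  # three identical "inputs"

# One key shared by every batch element (in_axes=None): identical samples.
shared_key = jax.random.PRNGKey(0)
same = jax.vmap(sample_pm1, in_axes=(None, 0))(shared_key, probs_batch)

# One independent subkey per batch element: independent samples.
keys = jax.random.split(jax.random.PRNGKey(0), probs_batch.shape[0])
indep = jax.vmap(sample_pm1, in_axes=(0, 0))(keys, probs_batch)
```

The design choice is simply to make the key part of the batched input: `in_axes=(0, 0)` maps over the subkeys just like over the data.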
The following code reproduces the problem:
```python
import pennylane as qml
import jax
import jax.numpy as jnp
import numpy as np

dev = qml.device('default.qubit.jax', wires=4, shots=100)

@jax.jit  # QNode calls will now be jitted, and should run faster.
@qml.qnode(dev, interface='jax')
def qnode(params, inputs):
    layers, qubits, _ = params.shape
    # Encode the binary input X as RX rotations (angle 0 or pi/2).
    for i in range(qubits):
        qml.RX(jnp.pi / 2 * inputs[i], wires=i)
    qml.templates.layers.StronglyEntanglingLayers(params, wires=range(qubits))
    # Sample every qubit in the Z basis.
    return [qml.sample(qml.PauliZ(i)) for i in range(qubits)]

params_shape = (4, 4, 3)  # (layers, qubits, 3 rotation angles)
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key)
params = jax.random.uniform(subkey, params_shape)

# Map over the batch axis of the inputs; params are shared across the batch.
qnode_vmap = jax.vmap(qnode, in_axes=(None, 0), out_axes=0)

X_in = np.zeros((3, 4))  # a batch of 3 inputs X
X_in[2][0] = 1.  # change the 3rd X
X_in[2][1] = 1.

Y_vmap = qnode_vmap(params, X_in)
Y_vmap_T = jnp.transpose(Y_vmap, axes=(0, 2, 1))
# Fraction of samples on which inputs 0 and 1 agree.
compare = (jnp.isclose(Y_vmap_T[0], Y_vmap_T[1]) * 1).mean()
```
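A related JAX behaviour may matter because of the `@jax.jit` decorator: host-side randomness (for example a NumPy call used to pick a seed) runs only while the function is traced, so its value is frozen into the compiled function. The sketch below is a generic illustration with a hypothetical `draw` function; whether the device actually seeds its key this way is an assumption, not something confirmed here.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def draw(probs):
    # Plain NumPy: this runs only while JAX traces the function, and the
    # resulting integer is baked into the compiled program as a constant.
    seed = np.random.randint(0, 2**31 - 1)
    key = jax.random.PRNGKey(seed)
    return jax.random.bernoulli(key, probs, shape=(100, probs.shape[0]))


p = jnp.full((4,), 0.5)
first = draw(p)
second = draw(p)  # same shape/dtype -> cached compilation -> same seed
```

Both calls return the same samples, which is why explicit keys passed in as arguments are the usual way to get fresh randomness from jitted code.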
One can see that `compare = 1.0`, which means that all 100 samples for `X_in[0]` and `X_in[1]` are identical. But I would like independent samples for each input.
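The `compare` value is the fraction of (shot, qubit) entries on which the samples for the first two inputs coincide. A small sketch with hand-made ±1 arrays shows how to read it, assuming the vmapped output has the layout (batch, qubits, shots) that the transpose in the snippet implies: 1.0 means every entry matches, and with 100 shots on 4 qubits each mismatching entry lowers the value by 1/400. `agreement` is a hypothetical helper wrapping the expression from the post.

```python
import jax.numpy as jnp

BATCH, QUBITS, SHOTS = 3, 4, 100


def agreement(Y):
    """Fraction of (shot, qubit) entries where inputs 0 and 1 coincide.

    Y is assumed to have shape (batch, qubits, shots).
    """
    Y_T = jnp.transpose(Y, axes=(0, 2, 1))  # -> (batch, shots, qubits)
    return (jnp.isclose(Y_T[0], Y_T[1]) * 1).mean()


Y = jnp.ones((BATCH, QUBITS, SHOTS))  # every sample is +1
print(agreement(Y))  # identical samples -> 1.0

# Flip the first 50 shots of qubit 0 for input 1: 50 of 400 entries differ.
Y_flipped = Y.at[1, 0, :50].set(-1.0)
print(agreement(Y_flipped))  # 1 - 50/400 = 0.875
```

With truly independent fair-coin samples the value would hover around 0.5 rather than sit at exactly 1.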
Is this somehow possible?