Hi there,

I would like to do the following (the code below reproduces the problem). I have a circuit with a binary input `X` that prepares an initial state, followed by an arbitrary circuit from which I sample in the Z basis. In JAX I can use the `vmap` function to pass in a whole batch `X_batch` of inputs `X` and compute the samples for every `X` at the same time.
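For reference, here is a minimal pure-JAX sketch of the batching pattern used in the reproduction code, with a hypothetical deterministic function `toy_circuit` standing in for the QNode (it is not a quantum circuit). With `in_axes=(None, 0)`, `params` is passed unchanged to every call, while the inputs are mapped over their leading axis.

```python
import jax
import jax.numpy as jnp

def toy_circuit(params, x):
    # Hypothetical stand-in for the QNode: one deterministic output per "qubit",
    # computed from the shared parameters and the binary input bits.
    return jnp.cos(params + jnp.pi / 2 * x)

params = jnp.zeros(4)                    # shared by every batch element (in_axes=None)
X_batch = jnp.array([[0., 0., 0., 0.],
                     [0., 0., 0., 0.],   # same input as row 0
                     [1., 1., 0., 0.]])  # mapped over axis 0

batched = jax.vmap(toy_circuit, in_axes=(None, 0), out_axes=0)
out = batched(params, X_batch)
print(out.shape)  # (3, 4): one row per input X
```

Because `toy_circuit` is deterministic, identical inputs give identical rows, which is expected here; the question below is about the sampled case, where identical rows are not wanted.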

The problem is that if `X_batch` contains the same `X` several times, I get exactly the same samples for each copy.

I am not sure whether this is a bug or a feature, but I would like the samples to be random even when the inputs are identical. Is there a way to tell `vmap` to keep the samples random?
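In JAX, `vmap` does not generate randomness by itself: random draws come only from explicit PRNG keys, and a function that receives the same key and the same probabilities returns the same draws. One plausible explanation (an assumption, not confirmed against the device source) is that the sampling key is fixed once when the QNode is traced, so every batch element reuses it. The pure-JAX sketch below, with a hypothetical sampler `sample_z`, reproduces that symptom and shows the usual JAX remedy: split one key per batch element and map over the keys together with the inputs.

```python
import jax
import jax.numpy as jnp

SHOTS = 100

def sample_z(key, p_one):
    # Hypothetical sampler: draws SHOTS PauliZ eigenvalues (+1/-1) per wire,
    # where p_one[i] is the probability of measuring |1> on wire i.
    bits = jax.random.bernoulli(key, p_one, shape=(SHOTS, p_one.shape[0]))
    return 1 - 2 * bits.astype(jnp.int32)

p_batch = jnp.array([[0.5, 0.5, 0.5, 0.5],
                     [0.5, 0.5, 0.5, 0.5],   # same probabilities as row 0
                     [0.9, 0.9, 0.5, 0.5]])

# Symptom: one key shared by every batch element (in_axes=None for the key).
fixed_key = jax.random.PRNGKey(0)
same = jax.vmap(sample_z, in_axes=(None, 0))(fixed_key, p_batch)

# Remedy: one independent subkey per batch element, mapped over axis 0.
keys = jax.random.split(jax.random.PRNGKey(0), p_batch.shape[0])
indep = jax.vmap(sample_z, in_axes=(0, 0))(keys, p_batch)

print(bool(jnp.all(same[0] == same[1])))    # True: identical samples
print(bool(jnp.all(indep[0] == indep[1])))  # almost surely False
```

Whether the `default.qubit.jax` device can be given a different key per batch element in this way (its documentation describes a `prng_key` option) is not verified here.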

The following code reproduces the problem:

```python
import pennylane as qml
import jax
import jax.numpy as jnp
import numpy as np

dev = qml.device('default.qubit.jax', wires=4, shots=100)

@jax.jit  # QNode calls will now be jitted, and should run faster.
@qml.qnode(dev, interface='jax')
def qnode(params, inputs):
    layers, qubits, _ = params.shape
    # Encode the binary input X: inputs[i] = 0 gives the identity, 1 gives RX(pi/2).
    for i in range(qubits):
        qml.RX(jnp.pi / 2 * inputs[i], wires=i)
    qml.templates.layers.StronglyEntanglingLayers(params, wires=range(qubits))
    # One array of PauliZ samples (one entry per shot) for each wire.
    return [qml.sample(qml.PauliZ(i)) for i in range(qubits)]

params_shape = (4, 4, 3)  # (layers, qubits, 3 rotation angles)
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key)
params = jax.random.uniform(subkey, params_shape)

# Share params across the batch and map over the leading axis of the inputs.
qnode_vmap = jax.vmap(qnode, in_axes=(None, 0), out_axes=0)

X_in = np.zeros((3, 4))  # a batch of 3 inputs X; X_in[0] and X_in[1] are identical
X_in[2][0] = 1.  # change the 3rd X
X_in[2][1] = 1.

Y_vmap = qnode_vmap(params, X_in)
Y_vmap_T = jnp.transpose(Y_vmap, axes=(0, 2, 1))
# Fraction of (shot, wire) entries where the samples for X_in[0] and X_in[1] agree.
compare = (jnp.isclose(Y_vmap_T[0], Y_vmap_T[1]) * 1).mean()
```

One can see that `compare = 1.`, which means that all 100 samples (one per shot) for the first two inputs, `X_in[0]` and `X_in[1]`, are identical on every wire. But I would like independent samples for each input.
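To make the `compare` metric concrete, here is a sketch on synthetic ±1 arrays shaped like the batched QNode output (3 inputs, 4 wires, 100 shots; that shape is an assumption based on returning one `qml.sample` per wire). After the transpose, each batch element is a `(shots, wires)` table, and the metric is the fraction of entries where two tables agree, so `1.0` means every sample matched, while independent fair ±1 samples would agree on roughly half the entries.

```python
import jax
import jax.numpy as jnp

batch, wires, shots = 3, 4, 100

# Synthetic stand-in for Y_vmap with shape (batch, wires, shots):
# rows 0 and 1 are deliberately identical, row 2 is drawn separately.
key0, key1 = jax.random.split(jax.random.PRNGKey(42))
row = jnp.where(jax.random.bernoulli(key0, 0.5, (wires, shots)), -1, 1)
other = jnp.where(jax.random.bernoulli(key1, 0.5, (wires, shots)), -1, 1)
Y_vmap = jnp.stack([row, row, other])

Y_vmap_T = jnp.transpose(Y_vmap, axes=(0, 2, 1))  # (batch, shots, wires)

# Fraction of matching (shot, wire) entries between two batch elements.
compare_01 = jnp.mean((Y_vmap_T[0] == Y_vmap_T[1]).astype(jnp.float32))
compare_02 = jnp.mean((Y_vmap_T[0] == Y_vmap_T[2]).astype(jnp.float32))

print(Y_vmap_T.shape)     # (3, 100, 4)
print(float(compare_01))  # 1.0
print(float(compare_02))  # roughly 0.5
```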

Is this somehow possible?