I would like to do the following (the code below reproduces the problem). I have a circuit with a binary input X that prepares an initial state, followed by some arbitrary circuit from which I sample in the Z basis. In JAX I can use the vmap function to pass in a whole batch of inputs X_batch and compute the samples for every X at the same time.
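As a minimal illustration of the batching itself, here is a toy deterministic stand-in for the circuit (not the PennyLane QNode; `f`, `params` and `X_batch` below are hypothetical). `in_axes=(None, 0)` shares the first argument across the batch and maps over axis 0 of the second. For a deterministic function, identical rows necessarily give identical outputs:

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # Hypothetical deterministic stand-in for the circuit:
    # params is shared across the batch, x is a single input of shape (4,).
    return jnp.cos(params * x)

params = jnp.arange(4.0)
X_batch = jnp.array([[0., 0., 0., 0.],
                     [0., 0., 0., 0.],   # same as row 0
                     [1., 1., 1., 1.]])

# None: do not map over params; 0: map over the leading (batch) axis of X_batch.
f_batched = jax.vmap(f, in_axes=(None, 0), out_axes=0)
Y = f_batched(params, X_batch)
print(Y.shape)  # (3, 4)
```

Identical outputs for rows 0 and 1 are expected here; the question is about the case where the function is supposed to be random.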
The problem is that if X_batch contains the same X several times, I get exactly the same samples for each copy.
I am not sure whether this is a bug or a feature, but I would like the samples to be random even when the inputs are the same. Is there any way to tell vmap to keep the samples random?
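A pure-JAX sketch (no PennyLane) of one mechanism that would produce exactly this behaviour, under the assumption, not checked against PennyLane's source, that the device draws its samples from a single PRNG key fixed when the function is traced. If the key is not mapped by vmap (`in_axes=None`), every batch element consumes the same random numbers, so identical inputs yield identical samples; giving each batch element its own key from `jax.random.split` makes them independent. `sample_z` and `SHOTS` are hypothetical names:

```python
import jax
import jax.numpy as jnp

SHOTS = 100

def sample_z(key, p_plus):
    """Hypothetical helper: draw SHOTS Z eigenvalues (+1/-1) for one qubit,
    where p_plus is the probability of measuring +1."""
    bits = jax.random.bernoulli(key, p_plus, shape=(SHOTS,))
    return jnp.where(bits, 1, -1)

p_batch = jnp.array([0.5, 0.5, 0.5])  # three identical "inputs"
key = jax.random.PRNGKey(0)

# Shared key: the key is broadcast to every batch element -> identical samples.
shared = jax.vmap(sample_z, in_axes=(None, 0))(key, p_batch)

# One key per batch element: the samples are independent.
keys = jax.random.split(key, p_batch.shape[0])
independent = jax.vmap(sample_z, in_axes=(0, 0))(keys, p_batch)
```

In the setting of the question, the analogous change would be to map over a batch of keys alongside `X_in`; how (and whether) a key can be passed into a `default.qubit.jax` QNode depends on the PennyLane version and is not shown here.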
The following code reproduces the problem:
```python
import pennylane as qml
import jax
import jax.numpy as jnp
import numpy as np

dev = qml.device('default.qubit.jax', wires=4, shots=100)

@jax.jit  # QNode calls will now be jitted, and should run faster.
@qml.qnode(dev, interface='jax')
def qnode(params, inputs):
    layers, qubits, _ = params.shape
    # Encode the binary input as RX rotations by 0 or pi/2.
    for i in range(qubits):
        qml.RX(jnp.pi / 2 * inputs[i], wires=i)
    qml.templates.layers.StronglyEntanglingLayers(params, wires=range(4))
    return [qml.sample(qml.PauliZ(i)) for i in range(qubits)]

params_shape = (4, 4, 3)
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key)
params = jax.random.uniform(subkey, params_shape)

# Map over the batch axis of the inputs; params are shared by all batch elements.
qnode_vmap = jax.vmap(qnode, in_axes=(None, 0), out_axes=0)

X_in = np.zeros((3, 4))  # a batch of 3 inputs X
X_in[2] = 1.  # change the 3rd X; X_in[0] and X_in[1] stay identical

Y_vmap = qnode_vmap(params, X_in)
Y_vmap_T = jnp.transpose(Y_vmap, axes=(0, 2, 1))  # (batch, shots, qubits)
# Fraction of entries on which the samples of the two identical inputs agree.
compare = (jnp.isclose(Y_vmap_T[0], Y_vmap_T[1]) * 1).mean()
```
One can see that compare = 1., which means that all the samples (100 shots on each of the 4 qubits) for the two identical inputs in X_batch are identical. But I would like to get independent samples for each input.
Is this somehow possible?
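To check this for every pair of inputs at once, one could compute an agreement matrix over an array shaped like `Y_vmap_T`, i.e. (batch, shots, qubits). For a ±1 outcome with P(+1) = p, two independent shots agree with probability p² + (1 − p)², which is below 1 unless p is 0 or 1, so an off-diagonal entry of exactly 1 for two identical inputs points to a shared random stream. `pairwise_agreement` is a hypothetical helper, shown here on a small synthetic array:

```python
import jax.numpy as jnp

def pairwise_agreement(Y):
    """Hypothetical helper. Y has shape (batch, shots, qubits); returns a
    (batch, batch) matrix whose (i, j) entry is the fraction of positions
    where the samples Y[i] and Y[j] are equal."""
    # Broadcast to (batch, batch, shots, qubits) and average the matches.
    matches = Y[:, None, :, :] == Y[None, :, :, :]
    return matches.mean(axis=(2, 3))

# Synthetic example: batch=3, shots=2, qubits=2; rows 0 and 1 are identical.
Y = jnp.array([
    [[1, -1], [1, 1]],
    [[1, -1], [1, 1]],
    [[-1, -1], [1, -1]],
])
A = pairwise_agreement(Y)
```

Here `A[0, 1]` is 1.0 and `A[0, 2]` is 0.5; applied to `Y_vmap_T`, independent sampling would typically give off-diagonal entries below 1 for the two identical inputs.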