Sampling with JAX and vmap not random anymore

Hi there,

I have a circuit in which a binary input X prepares an initial state, followed by an arbitrary circuit from which I sample in the Z basis. In JAX I can use vmap to pass in a whole batch X_batch of inputs and compute the samples for every X at the same time.
The problem is that if X_batch contains the same X several times, I get exactly the same samples for each of them.
I am not sure if this is a bug or a feature, but I would like the sampling to be random even when the inputs are the same. Is there any way of telling vmap to keep the samples random?

The following code reproduces the problem:

import pennylane as qml
import jax
import jax.numpy as jnp
import numpy as np

dev = qml.device('default.qubit.jax', wires=4, shots = 100)

@jax.jit  # QNode calls will now be jitted, and should run faster.
@qml.qnode(dev, interface='jax')
def qnode(params, inputs):
    layers, qubits, _ = params.shape
    for i in range(qubits):
        qml.RX(jnp.pi/2*inputs[i], wires=i)
    qml.templates.layers.StronglyEntanglingLayers(params, wires=range(4))
    return [qml.sample(qml.PauliZ(i)) for i in range(qubits)]

params_shape = (4, 4, 3)

key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key)
params = jax.random.uniform(subkey, params_shape)

qnode_vmap = jax.vmap(qnode, in_axes=(None, 0), out_axes=0)

X_in = np.zeros((3, 4)) # a batch of 3 inputs X
X_in[2][0] = 1. # change the 3rd X
X_in[2][1] = 1.
Y_vmap = qnode_vmap(params, X_in)


Y_vmap_T = jnp.transpose(Y_vmap, axes=(0,2,1))

compare = (jnp.isclose(Y_vmap_T[0], Y_vmap_T[1])*1).mean()

Running this gives compare = 1.0, which means that all 100 samples (shots) for X_in[0] and X_in[1] are identical. I would like independent samples for each input instead.
Is this somehow possible?
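To isolate the effect, here is a minimal PennyLane-free sketch. It assumes (this is not confirmed by the code above) that the device draws all shots from a single fixed PRNG key, so when vmap batches only the inputs, every batch element is handed the same key. `sample_bits` is a hypothetical stand-in for the QNode, with a made-up probability of measuring +1.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the QNode: draw `shots` Z-basis outcomes (+1/-1)
# whose probability of +1 depends on the input x (0.25 for x=0, 0.75 for x=1).
def sample_bits(x, key, shots=100):
    p_plus = 0.25 + 0.5 * x
    bits = jax.random.bernoulli(key, p_plus, shape=(shots,))
    return jnp.where(bits, 1, -1)

fixed_key = jax.random.PRNGKey(1)
X_in = jnp.array([0.0, 0.0, 1.0])  # the first two inputs are identical

# Only the inputs are batched; the key is broadcast (in_axes=None),
# mimicking a device that reuses one key for every batch element.
shared = jax.vmap(sample_bits, in_axes=(0, None))(X_in, fixed_key)
print(shared.shape)                       # (3, 100)
print((shared[0] == shared[1]).mean())    # 1.0: same input + same key -> same samples
```

Because JAX random functions are pure functions of their key, the same input with the same key always yields the same draws, which matches what the vmapped QNode shows.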

Hi @PatrickHuembeli! Unfortunately I don’t have much experience with JAX’s jit and vmap, but I’m curious whether the approach taken in the following demo,

namely defining the device and the QNode inside the jitted function, solves the problem in your case?

Hi @josh, thanks for the reply, it gave me the last hint I needed. That solution covers a single execution of the sampling, but I figured out how to do the vmapping with random keys. First, define the circuit as described in the link you added. I slightly modified it and added an input x_in so it can be used with vmap later:

@jax.jit
def circuit(param, x_in, key):
    # Notice how the device construction now happens within the jitted method.
    # Also note the added '.jax' to the device path.
    dev = qml.device("default.qubit.jax", wires=2, shots=10, prng_key=key)

    # Now we can create our qnode within the circuit function.
    @qml.qnode(dev, interface="jax", diff_method=None)
    def my_circuit():
        qml.RY(x_in*np.pi, wires=0)
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.sample(qml.PauliZ(0))
    return my_circuit()

key = jax.random.PRNGKey(0)

Normally I would vmap with in_axes = (None, 0, None), which means only x_in comes in as a batch. The "trick" is that you have to provide a separate key for each x_in, so the key must be batched as well:

circuit_vmap = jax.vmap(circuit, in_axes=(None, 0, 0), out_axes=0)

To run the vmapped circuit on an X_batch of shape (batch_size, 1) in this example, you also need to generate batch_size keys:

subkeys = jax.random.split(key, batch_size)

And run the circuit with:

circuit_vmap(params, X_batch, subkeys)
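The key-per-element pattern can also be checked without PennyLane. This is a minimal sketch under the same assumptions as before: `sample_bits` is a hypothetical stand-in for the circuit (redefined here so the block runs on its own), and the batch values are made up. Splitting one key into batch_size subkeys and batching both arguments with in_axes=(0, 0) gives independent samples even for identical inputs.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the circuit: +1/-1 outcomes, P(+1) depends on x.
def sample_bits(x, key, shots=100):
    p_plus = 0.25 + 0.5 * x
    bits = jax.random.bernoulli(key, p_plus, shape=(shots,))
    return jnp.where(bits, 1, -1)

key = jax.random.PRNGKey(0)
X_batch = jnp.array([0.0, 0.0, 1.0])  # first two inputs are identical
batch_size = X_batch.shape[0]

# One subkey per batch element (shape (batch_size, 2) for the default PRNG).
subkeys = jax.random.split(key, batch_size)

# Batch both the inputs and the keys, as in in_axes=(None, 0, 0) above.
samples = jax.vmap(sample_bits, in_axes=(0, 0))(X_batch, subkeys)
print(samples.shape)                       # (3, 100)
print((samples[0] == samples[1]).mean())   # no longer 1.0: rows are drawn independently
```

Splitting (rather than reusing one key or incrementing a seed by hand) is the usual JAX way to get statistically independent streams, so each batch element gets its own shots.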

Glad you got it working @PatrickHuembeli! And thanks for posting your solution, this will be helpful for anyone with the same problem (and for me as well 🙂)