Default.qubit.jax and weighted sampling

Hi, I am currently trying to add weighted sampling to my VQE using JAX: I want to sample according to coefficients stored in a JAX device array. However, it seems that it’s not possible to specify the number of shots with a JAX device array, since when I try

@jax.jit
def Circuits_Sample_Ub(key, params, inputs, nbr_shot):
    dev = qml.device('default.qubit.jax', wires=n_qubits//2, shots=nbr_shot, prng_key=key)  
    @jax.jit
    @qml.qnode(dev, interface='jax', diff_method=None)
    def qnode(params, inputs):
        for i in range(n_qubits//2):
            qml.RX(jnp.pi*inputs[i], wires=i)
        brick_wall_entangling(params)
        return qml.sample()
    return qnode(params, inputs)

Circuits_Sample_Ub(key, params_A, bitstringA[0], (coef[0]**2*1000).astype(int))

I get
DeviceError: Shots must be a single non-negative integer or a sequence of non-negative integers.
and this happens even without jitting. Is there a solution? Thanks a lot.
Paulin

Hi @paulin_ds,

You can in fact specify the number of shots; our Using JAX with PennyLane demo has a section on using shots.

However, if you want to pass it as an argument of Circuits_Sample_Ub, you need to mark it as a static argument. Otherwise JAX will replace it with an abstract tracer, which is not a concrete integer and therefore cannot be used as the number of shots. You can learn more about this in the JAX documentation here.

Here’s an example. Following the jax.jit docs, I used functools.partial to pass the index of the static argument (static_argnums=3) to jax.jit. I’m not sure whether this will break something in other parts of your code, but you can give it a try.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

@partial(jax.jit,static_argnums=3)
def Circuits_Sample_Ub(key, params, inputs, nbr_shot):
    dev = qml.device('default.qubit.jax', wires=n_qubits//2, shots=nbr_shot, prng_key=key)  
    @jax.jit
    @qml.qnode(dev, interface='jax', diff_method=None)
    def qnode(params, inputs):
        for i in range(n_qubits//2):
            qml.RX(jnp.pi*inputs[i], wires=i)
        #brick_wall_entangling(params)
        qml.BasicEntanglerLayers(weights=params, wires=range(n_qubits//2)) # Added this
        return qml.sample()
    return qnode(params, inputs)

## -- Added this ---
n_qubits = 2
key = jax.random.PRNGKey(0)
shape = qml.BasicEntanglerLayers.shape(n_layers=2, n_wires=n_qubits//2)
params_A = np.random.random(size=shape)
inputs = np.ones(n_qubits//2)
nbr_shot = 10
## ---

Circuits_Sample_Ub(key, params_A, inputs, nbr_shot)

You will see that this example is similar to your code but I had to add some things since I didn’t have access to all of your functions.
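One detail to check before reusing this with the call from the question: a static argument has to be hashable, and JAX arrays are not, so `(coef[0]**2*1000).astype(int)` would need to become a plain Python int first (presumably a plain int is also what the device's shot check wants in the non-jitted case). The sketch below is pure JAX and does not touch PennyLane: `sample_outcomes` is a hypothetical stand-in for the circuit that just draws `n_shots` outcome indices from a probability vector, and the values in `coef` are made up.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the shot-dependent circuit: draws `n_shots`
# outcome indices from the probability vector `probs`.
@partial(jax.jit, static_argnums=2)
def sample_outcomes(key, probs, n_shots):
    # n_shots is static, so here it is a concrete Python int that can
    # fix the output shape.
    return jax.random.choice(key, probs.shape[0], shape=(n_shots,), p=probs)


coef = jnp.array([0.8, 0.6])  # made-up coefficients for illustration

# Round, then convert each JAX scalar to a plain (hashable) Python int.
shots = [int(s) for s in jnp.round(coef**2 * 1000).astype(int)]

key = jax.random.PRNGKey(0)
probs = jnp.array([0.5, 0.5])
samples = sample_outcomes(key, probs, shots[0])
print(shots, samples.shape)  # [640, 360] (640,)
```

Each distinct value in `shots` still gets its own compilation, since static arguments are part of the jit cache key.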

I hope this helps you!

Yes, of course. But I think that if we make this argument static, the function recompiles every time we change the number of shots. This makes adaptive sampling very inefficient. My question was rather: would it be possible to keep this argument non-static (so that it could be traced by JAX and the function wouldn’t have to be recompiled every time)?
Thanks a lot
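The recompilation concern can be checked with a trace counter: a Python side effect inside a jitted function only runs while JAX traces it, so counting it shows how many compilations happen. The pure-JAX sketch below compares a static shot count with a hypothetical workaround that keeps the shot count traced: it always draws a fixed buffer of `MAX_SHOTS` samples and keeps only the first `n_shots` through a mask. `MAX_SHOTS`, the function names, and using the mean outcome index as the estimate are assumptions for illustration. This is not a PennyLane feature, although the same masking could in principle be applied to samples from a device created with a fixed maximum number of shots.

```python
from functools import partial

import jax
import jax.numpy as jnp

MAX_SHOTS = 1000  # hypothetical upper bound; n_shots must not exceed it
trace_counts = {"static": 0, "traced": 0}


@partial(jax.jit, static_argnums=2)
def static_shots(key, probs, n_shots):
    trace_counts["static"] += 1  # runs only while tracing, not on cached calls
    return jax.random.choice(key, probs.shape[0], shape=(n_shots,), p=probs)


@jax.jit
def traced_shots_mean(key, probs, n_shots):
    trace_counts["traced"] += 1
    # The buffer shape never depends on n_shots, so n_shots can stay traced.
    draws = jax.random.choice(key, probs.shape[0], shape=(MAX_SHOTS,), p=probs)
    mask = jnp.arange(MAX_SHOTS) < n_shots  # keep only the first n_shots draws
    # Mean outcome index over the kept draws.
    return jnp.sum(jnp.where(mask, draws, 0)) / n_shots


key = jax.random.PRNGKey(0)
probs = jnp.array([0.5, 0.5])
for n in (10, 20, 10, 30):
    static_shots(key, probs, n)
    traced_shots_mean(key, probs, n)

print(trace_counts)  # {'static': 3, 'traced': 1}
```

The trade-off is that every call pays for `MAX_SHOTS` draws, whatever `n_shots` is.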

Hi @paulin_ds, it is a known issue that creating the device inside a jitted function causes problems with shots and wires. That’s something we want to improve in the future. I would suggest creating the device outside the jitted function, as below. The function will still need to be recompiled whenever the number of shots changes, but unfortunately we do not have a better solution at the moment.

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

@jax.jit
def Circuits_Sample_Ub(params, inputs):

    @qml.qnode(dev, interface='jax', diff_method=None)
    def qnode(params, inputs):
        for i in range(n_qubits//2):
            qml.RX(jnp.pi*inputs[i], wires=i)
        #brick_wall_entangling(params)
        qml.BasicEntanglerLayers(weights=params, wires=range(n_qubits//2)) # Added this
        return qml.sample()
    return qnode(params, inputs)


## -- Added this ---
n_qubits=2
key = jax.random.PRNGKey(0)
shape = qml.BasicEntanglerLayers.shape(n_layers=2, n_wires=n_qubits//2)
params_A = np.random.random(size=shape)
inputs = np.ones(n_qubits//2)
nbr_shot = 10
## ---
dev = qml.device('default.qubit.jax', wires=n_qubits // 2, shots=nbr_shot, prng_key=key)
Circuits_Sample_Ub(params_A, inputs)
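A general JAX caveat with this pattern: a jitted function reads globals such as `dev` only while it is being traced, and its cache is keyed on the arguments, so re-assigning `dev` with a different shot count may leave the already-compiled `Circuits_Sample_Ub` using the old device. One possible way around that, sketched in pure JAX: build one jitted sampler per shot count in a small factory cached with `functools.lru_cache`, so each shot count compiles once and is reused afterwards. `n_shots_global`, `make_sampler`, and the probability-vector sampler are hypothetical stand-ins for the device and QNode; with PennyLane the factory body would presumably create the device and QNode instead.

```python
import functools

import jax
import jax.numpy as jnp

n_shots_global = 10  # plays the role of the global `dev` above


@jax.jit
def uses_global(key, probs):
    # n_shots_global is read at trace time and baked into the compiled code.
    return jax.random.choice(key, probs.shape[0], shape=(n_shots_global,), p=probs)


@functools.lru_cache(maxsize=None)
def make_sampler(n_shots):
    # One jitted function per distinct shot count, compiled on first use.
    @jax.jit
    def sampler(key, probs):
        return jax.random.choice(key, probs.shape[0], shape=(n_shots,), p=probs)

    return sampler


key = jax.random.PRNGKey(0)
probs = jnp.array([0.5, 0.5])

first = uses_global(key, probs).shape
n_shots_global = 20  # changing the global does not trigger a retrace
second = uses_global(key, probs).shape
print(first, second)                         # (10,) (10,)
print(make_sampler(20)(key, probs).shape)    # (20,)
print(make_sampler(20) is make_sampler(20))  # True
```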

OK, I understand. Thanks anyway for these nice explanations.
All the best,
Paulin

Thank you very much for your question, @paulin_ds. This helps us make PennyLane better 🙂. Please let us know if you find other use cases that are not supported.