qml.counts is not working with jax.jit

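import jax
import numpy as np
import pennylane as qml

# The device was not shown in the original post; a 2-wire JAX device with shots is assumed here.
dev = qml.device("default.qubit.jax", wires=2, shots=100)

@qml.qnode(dev, interface="jax")
def circuit(param):
    # These two gates represent our QML model.
    qml.RX(param[0], wires=0)
    qml.CNOT(wires=[0, 1])

    return qml.counts(wires=[0], all_outcomes=True)

# Raises an error when the jitted QNode returns qml.counts
print(f"Result: {repr(jax.jit(circuit)(np.array([0.123])))}")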

Hi Team, I’m trying to use jax.jit on a circuit that returns qml.counts(), but it throws an error. Here I’m just reproducing the error in a simple circuit. I’d appreciate your guidance on it.

Hi @roysuman088, thank you for posting this here. There does seem to be an issue with counts. I can get it working with qml.sample, though, and you could then calculate the counts from the samples. Here’s the code with the recommended construction for jax.jit QNodes that involve randomness.

import pennylane as qml
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

@jax.jit
def sample_circuit(param, key):
    # Device construction should happen inside a `jax.jit` decorated
    # method when using a PRNGKey.
    dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100)
    
    @qml.qnode(dev, interface="jax", diff_method=None)
    def circuit(param):
        # These two gates represent our QML model.
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.sample(wires=0)  # qml.counts(wires=[0], all_outcomes=True) fails under jax.jit
    
    return circuit(param)

samples = sample_circuit(jnp.array([1.123]), jax.random.PRNGKey(0))
print(f"Result: {samples}")

Note that I have changed where and how the device is created, I’ve added diff_method=None, and I’ve returned qml.sample instead of qml.counts.
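As suggested above, the counts can be computed from the samples inside the jitted function. Below is a minimal pure-JAX sketch of that idea. Because jax.jit needs output shapes to be known at trace time, it returns a fixed-length array with one entry per possible outcome (similar in spirit to all_outcomes=True) rather than a dictionary. To keep it self-contained, the QNode is replaced by a hypothetical stand-in, `fake_sample_circuit`, which draws single-wire samples with P(1) = sin²(θ/2), the outcome probability on wire 0 after RX(θ) on |0⟩; in practice you would pass the output of qml.sample instead. `counts_from_samples` and `sample_and_count` are also hypothetical helpers, not PennyLane functions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_wires")
def counts_from_samples(samples, num_wires):
    # Accept shape (shots,) for one wire or (shots, num_wires) for several.
    bits = samples.reshape(-1, num_wires).astype(jnp.int32)
    # Read each row as a big-endian bitstring, e.g. [1, 0] -> 2.
    weights = 2 ** jnp.arange(num_wires - 1, -1, -1)
    indices = bits @ weights
    # bincount needs a static `length` under jit: one bin per possible outcome.
    return jnp.bincount(indices, length=2**num_wires)


@jax.jit
def fake_sample_circuit(theta, key):
    # Hypothetical stand-in for the QNode: 100 single-wire samples where
    # P(1) = sin^2(theta / 2), as after RX(theta) on |0>.
    p_one = jnp.sin(theta / 2) ** 2
    return jax.random.bernoulli(key, p_one, shape=(100,)).astype(jnp.int32)


@jax.jit
def sample_and_count(theta, key):
    # Sampling and counting both happen inside one jitted function.
    samples = fake_sample_circuit(theta, key)
    return counts_from_samples(samples, num_wires=1)


counts = sample_and_count(1.123, jax.random.PRNGKey(0))
print(counts)  # length-2 array: counts for outcomes 0 and 1, summing to 100
```

If you need something closer to the dictionary that qml.counts returns, you could convert the array after the jitted call, for example with {format(i, "01b"): int(n) for i, n in enumerate(counts)}. Reusing the same PRNGKey gives the same samples, so use jax.random.split to get a fresh key for each new draw.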

Please let me know if you have any questions about this!

Hi @roysuman088,

My colleagues found the source of the problem! There were actually two bugs:

  1. Broadcasting and counts aren’t working well together.
  2. Jitting and counts aren’t working well together.

Broadcasting:
When you pass the parameter as jnp.array([1.123]) (notice the square brackets), PennyLane thinks you’re broadcasting, that is, sending several values of your input at the same time so that the circuit runs several times with a single call. If you remove the square brackets, jnp.array(1.123), PennyLane understands that this is a single parameter value, as usual. For some reason, broadcasting and counts aren’t working well together. We have opened an issue for this here.
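What matters here is the shape of the parameter: qml.RX expects a scalar angle, so an extra leading axis is read as a batch of angles, even when that batch holds only one value. The pure-JAX sketch below (no PennyLane involved) just illustrates the shapes at play. `expval_z` is a hypothetical helper that uses the analytic result ⟨Z⟩ = cos θ on wire 0 after RX(θ) on |0⟩, so its output shape mirrors the batching a broadcast circuit would produce.

```python
import jax.numpy as jnp


def expval_z(theta):
    # Hypothetical analytic stand-in: <Z> on wire 0 after RX(theta)|0> is cos(theta).
    # (The CNOT in the example circuit does not change this value.)
    return jnp.cos(theta)


single = jnp.array(1.123)                    # shape (): one value, no broadcasting
batch_of_one = jnp.array([1.123])            # shape (1,): a batch holding one value
batch_of_three = jnp.array([0.1, 0.2, 0.3])  # shape (3,): three values at once

for theta in (single, batch_of_one, batch_of_three):
    print(theta.shape, expval_z(theta).shape)
```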

Jitting and counts:
Counts and jitting aren’t working well together at the moment. We have opened an issue here.

The following code, without broadcasting, without jitting, and with counts, should work for you.

import pennylane as qml
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

# @jax.jit  # disabled for now: qml.counts does not yet work under jax.jit
def sample_circuit(param, key):
    # Device construction should happen inside a `jax.jit` decorated
    # method when using a PRNGKey.
    dev = qml.device('default.qubit.jax', wires=2, prng_key=key, shots=100)
    
    @qml.qnode(dev, interface="jax", diff_method=None)
    def circuit(param):
        # These two gates represent our QML model.
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.counts(wires=0)
    
    return circuit(param)

counts = sample_circuit(jnp.array(1.123), jax.random.PRNGKey(0))
print(f"Result: {counts}")

Please let me know if you have any questions about this!

Hi @CatalinaAlbornoz, thanks for the solution, and it’s good that you’ve opened an issue for it. I’ve been working with larger numbers of qubits, and to boost performance I need to jit my circuit. So, just to clarify: for now, the only option for operations that need counts is to use samples, right?

Hi @roysuman088,

You could also set shots=1 and use qml.expval or qml.probs. I’m not sure whether this would be faster, but you could test it on a small example.
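For context, with a finite number of shots the probabilities are estimated from the observed outcomes, so they are sample frequencies; multiplying by the number of shots recovers the counts. With shots=1, each execution gives a one-hot vector, and summing those over repeated runs also yields counts. The pure-JAX sketch below assumes `probs` was estimated from exactly `shots` samples; `counts_from_probs` is a hypothetical helper, not a PennyLane function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def counts_from_probs(probs, shots):
    # Assumes probs are sample frequencies estimated from `shots` samples,
    # so probs * shots is an integer vector up to float rounding.
    return jnp.rint(probs * shots).astype(jnp.int32)


# Probabilities estimated from 100 shots on one wire (illustrative values).
print(counts_from_probs(jnp.array([0.37, 0.63]), 100))

# One-hot results from three separate shots=1 runs, summed into counts.
one_hot_runs = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
print(one_hot_runs.sum(axis=0).astype(jnp.int32))
```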

Hi @CatalinaAlbornoz, thanks for the update. However, when I try the first suggested code, I get the following error, which seems to come from an incompatibility between JAX and PennyLane, with both JAX 0.4.3 and 0.4.4. Could you please take a look?

For jax version 0.4.3

For jax version 0.4.4

Hi @roysuman088! Please could you try:

pip install -U jax jaxlib

A more recent version of jax has been released that should solve this compatibility issue.
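After upgrading, one way to confirm which versions are actually installed is to query the package metadata from Python. This is a small standard-library sketch; `installed_versions` is a hypothetical helper. If PennyLane imports correctly, qml.about() prints similar information.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "pennylane")):
    # Map each package name to its installed version, or None if it is missing.
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```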