JAX with default.mixed device

Hi,

I have a question regarding the JAX interface and whether it can be used to simulate noisy quantum computers. I am interested in JAX because of the jax.vmap function, which lets me vectorize (batch) the quantum circuit computations.

Here is what I tried. First, I defined a noisy quantum circuit such as:

import pennylane as qml

def noisy_circuit(prob, **kwargs):
    for k in range(len(G.nodes)):
        qml.BitFlip(prob, wires=k)
    return qml.expval(qml.PauliZ(0))

and then vectorized it with jax.vmap and ran it as follows:

import jax
import pennylane as qml

dev = qml.device("default.mixed", wires=len(G.nodes))
qcircuit = qml.QNode(noisy_circuit, dev, interface="jax")
vcircuit = jax.vmap(qcircuit)
where G is a graph that I defined with the networkx package.

If I define probs = jnp.array([0., 0.05, 0.1]) (using import jax.numpy as jnp), then vcircuit(probs) raises the following error: NotImplementedError: batching rules are implemented only for id_tap, not for call. However, qcircuit(0.01) works perfectly, as expected.

Thank you very much in advance!

Cheers,
Javier.



Hi @Javier,

Welcome back, thanks for the question! :slight_smile:

It’s exciting to see this use case! The underlying issue is likely that, internally in PennyLane, only the default.qubit.jax device supports JAX computations natively, and jax.vmap should be compatible with that device. For other devices, PennyLane internally converts between interface-specific objects (e.g., jnp.array) and NumPy ndarray objects. This conversion likely breaks JAX’s batching logic.
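As a rough illustration of why a NumPy conversion can break batching (this does not reproduce PennyLane’s internal code path; both functions below are hypothetical stand-ins for a circuit): under jax.vmap, a function receives an abstract JAX tracer instead of a concrete array. A function written purely with jax.numpy traces through fine, whereas one that calls np.asarray on its input fails, because a tracer cannot be turned into a concrete NumPy array.

```python
import jax
import jax.numpy as jnp
import numpy as np


def expval_jax_only(p):
    # Hypothetical stand-in for a circuit: <Z> after a bit flip with probability p.
    # Only JAX operations are used, so jax.vmap can trace and batch it.
    return 1.0 - 2.0 * p


def expval_via_numpy(p):
    # Hypothetical stand-in for a device that converts its inputs to NumPy.
    # np.asarray needs a concrete value, but under vmap `p` is an abstract tracer.
    p_np = np.asarray(p)
    return 1.0 - 2.0 * p_np


probs = jnp.array([0.0, 0.05, 0.1])

# Works: one expectation value per probability.
print(jax.vmap(expval_jax_only)(probs))

try:
    jax.vmap(expval_via_numpy)(probs)
except jax.errors.TracerArrayConversionError as err:
    print("vmap failed:", type(err).__name__)
```

The exact error differs from the one reported in the question, but the root cause sketched here is the same kind of problem: leaving JAX’s data structures midway through a batched computation.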

To make this feature compatible with default.mixed, a likely but complex solution would be to create a JAX-only device that inherits from DefaultMixed and uses only JAX data structures under the hood (akin to default.qubit.jax). It’s worth noting that such a device could be genuinely useful.
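For intuition only, here is a minimal sketch, assuming a tiny hand-written density-matrix simulator in pure JAX. The helpers embed, bit_flip and noisy_circuit_jax are hypothetical and not part of PennyLane, and a fixed number of wires replaces the graph G from the question. Because every step uses jax.numpy, jax.vmap can batch over the bit-flip probability, which is the kind of behavior a JAX-only mixed-state device would provide.

```python
import jax
import jax.numpy as jnp

I2 = jnp.eye(2)
X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])


def embed(op, wire, n_wires):
    # Hypothetical helper: tensor `op` on `wire` with identities elsewhere
    # (wire 0 is the leftmost Kronecker factor).
    full = jnp.array([[1.0]])
    for w in range(n_wires):
        full = jnp.kron(full, op if w == wire else I2)
    return full


def bit_flip(rho, p, wire, n_wires):
    # Bit-flip channel in Kraus form: K0 = sqrt(1 - p) I, K1 = sqrt(p) X,
    # so rho -> (1 - p) rho + p X rho X on the chosen wire.
    x_full = embed(X, wire, n_wires)
    return (1.0 - p) * rho + p * x_full @ rho @ x_full


def noisy_circuit_jax(p, n_wires=3):
    # Start in |0...0><0...0| and apply a bit flip with probability p on every wire.
    dim = 2 ** n_wires
    rho = jnp.zeros((dim, dim)).at[0, 0].set(1.0)
    for k in range(n_wires):
        rho = bit_flip(rho, p, k, n_wires)
    # Expectation value <Z_0> = Tr(rho Z_0).
    return jnp.trace(rho @ embed(Z, 0, n_wires))


probs = jnp.array([0.0, 0.05, 0.1])
vcircuit = jax.vmap(noisy_circuit_jax)  # batches over p only
print(vcircuit(probs))
```

Since only the bit flip on wire 0 affects Z_0, each batched result should equal 1 - 2p, which makes a convenient sanity check.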


Hi again,

Related to default.mixed: if I want to run a quantum circuit with this device on a GPU, do I have to change something in my code? (I must say that I have never used a GPU with Python.) Let’s say that I have the noisy_circuit that I defined in my previous message, and now I do the following:

import pennylane as qml

dev = qml.device("default.mixed", wires=len(G.nodes))
qcircuit = qml.QNode(noisy_circuit, dev, interface="autograd")
result = qcircuit(0.01)

The idea would be to modify the circuit by introducing some parameters and then run a few rounds of gradient descent, but for now I want to keep it simple.

Thank you very much in advance!!

Cheers,
Javier.

Hey @Javier!

if I want to run a quantum circuit with this device using the GPU, do I have to change something in my code?

Unfortunately, default.mixed is not currently compatible with running on a GPU, because it uses a pure NumPy-based implementation. That said, as @antalszava pointed out, this would make a great feature request!

The idea would be to modify the circuit introducing some parameters and then do some gradient descent rounds, but for now I want to keep it simple.

That should still be possible with default.mixed; let us know if you need a hand with that part!
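Purely to illustrate the optimization loop, here is a minimal pure-JAX sketch, assuming a hypothetical one-qubit model (an RX(theta) rotation followed by a bit flip with probability p) instead of the actual default.mixed device; rx and cost are illustrative names only. For this model, <Z> = (1 - 2p) cos(theta), so minimizing it by gradient descent should drive theta toward pi and the cost toward -(1 - 2p).

```python
import jax
import jax.numpy as jnp

I2 = jnp.eye(2, dtype=jnp.complex64)
X = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)


def rx(theta):
    # RX(theta) = cos(theta/2) I - i sin(theta/2) X
    return jnp.cos(theta / 2) * I2 - 1j * jnp.sin(theta / 2) * X


def cost(theta, p):
    # Hypothetical noisy model: |0><0| -> RX(theta) -> bit flip(p) -> <Z>.
    rho = jnp.array([[1.0, 0.0], [0.0, 0.0]], dtype=jnp.complex64)
    u = rx(theta)
    rho = u @ rho @ u.conj().T
    rho = (1.0 - p) * rho + p * X @ rho @ X
    return jnp.real(jnp.trace(rho @ Z))


grad_fn = jax.grad(cost)  # derivative with respect to theta (argument 0)

theta = 0.1  # initial parameter
p = 0.05     # fixed noise probability
lr = 0.5     # learning rate

for step in range(100):
    theta = theta - lr * grad_fn(theta, p)

print(theta, cost(theta, p))
```

With default.mixed itself, the same loop would differentiate a QNode through the chosen interface; the sketch only shows the shape of the gradient-descent rounds.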


Hi @NikSchet,

An update here: we’ve recently been working on backpropagation support for default.mixed and on improving its compatibility with JAX’s vmap.

See the following (approved) pull request, which will be merged into master and released with v0.25.0 of PennyLane: https://github.com/PennyLaneAI/pennylane/pull/2776

Furthermore, we also have GPU support for default.mixed with Torch on the radar.