Running a quantum circuit in batches using jax.vmap on the lightning.gpu device

Hi @CatalinaAlbornoz,

So it looks like the solution is to add the `@qml.qjit` decorator to my function?

My code now looks like this:

```python
import jax
import jax.numpy as jnp
import pennylane as qml

# Added to silence some warnings.
jax.config.update("jax_enable_x64", True)

dev = qml.device("lightning.gpu", wires=2)

@qml.qjit(autograph=True) # Added this now
@qml.qnode(dev, interface="jax")
def circuit(param):
    # These two gates represent our QML model.
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # The expval here will be the "cost function" we try to minimize.
    # Usually, this would be defined by the problem we want to solve,
    # but for this example we'll just use a single PauliZ.
    return qml.expval(qml.PauliZ(0))

print("\n\nBatching and Evolutionary Strategies")
print("------------------------------------")

# Create a vectorized version of our original circuit.
vcircuit = jax.vmap(circuit)

# Now, we call the ``vcircuit`` with multiple parameters at once and get back a
# batch of expectations.
# This example runs 3 quantum circuits in parallel.
batch_params = jnp.array([1.02, 0.123, -0.571])

batched_results = vcircuit(batch_params)
print(f"Batched result: {batched_results}")
```

And I get the expected results:

```
Batching and Evolutionary Strategies
------------------------------------
Batched result: [0.52336595 0.99244503 0.84136092]
```
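As a sanity check on these numbers: RX(θ) acting on |0⟩ gives ⟨Z₀⟩ = cos θ, and the CNOT that follows is controlled on wire 0, so it commutes with Z on wire 0 and leaves that expectation unchanged. The sketch below reproduces the batched output with plain JAX and no quantum device at all; `analytic_expval` is a hypothetical helper written only for this check, not a PennyLane function.

```python
import jax
import jax.numpy as jnp

# Match the float64 precision used in the circuit example.
jax.config.update("jax_enable_x64", True)

def analytic_expval(param):
    # Hypothetical helper (not part of PennyLane).
    # RX(param)|0> = cos(param/2)|0> - i sin(param/2)|1>, so on wire 0
    # <Z> = cos^2(param/2) - sin^2(param/2) = cos(param).
    # The CNOT is controlled on wire 0, so it commutes with Z_0 and
    # does not change this expectation value.
    return jnp.cos(param)

batch_params = jnp.array([1.02, 0.123, -0.571])

# Vectorize the analytic model the same way the circuit was vectorized.
expected = jax.vmap(analytic_expval)(batch_params)
print(f"Analytic result: {expected}")
```

If the values printed here agree with the `Batched result` line, the qjit-compiled, vmapped circuit is computing the same per-parameter expectations you would get by evaluating the circuit once per parameter.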

I can’t figure out why this works. Any insight would be helpful!