Hi @CatalinaAlbornoz,

So it looks like the solution is to add the `@qml.qjit` decorator to my function?

My code now looks like this:
```python
import jax
import jax.numpy as jnp
import pennylane as qml

# Enable 64-bit floats (added to silence some warnings).
jax.config.update("jax_enable_x64", True)

dev = qml.device("lightning.gpu", wires=2)

@qml.qjit(autograph=True)  # Added this now
@qml.qnode(dev, interface="jax")
def circuit(param):
    # These two gates represent our QML model.
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # The expval here will be the "cost function" we try to minimize.
    # Usually, this would be defined by the problem we want to solve,
    # but for this example we'll just use a single PauliZ.
    return qml.expval(qml.PauliZ(0))

print("\n\nBatching and Evolutionary Strategies")
print("------------------------------------")

# Create a vectorized version of our original circuit.
vcircuit = jax.vmap(circuit)

# Now, we call ``vcircuit`` with multiple parameters at once and get back a
# batch of expectation values.
# This example evaluates the circuit for 3 parameter values in one batched call.
batch_params = jnp.array([1.02, 0.123, -0.571])
batched_results = vcircuit(batch_params)
print(f"Batched result: {batched_results}")
```
and I get the expected results:
```
Batching and Evolutionary Strategies
------------------------------------
Batched result: [0.52336595 0.99244503 0.84136092]
```
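For reference, these numbers are just cos θ for each parameter: RX(θ) sends |0⟩ to cos(θ/2)|0⟩ − i sin(θ/2)|1⟩, and the CNOT only swaps amplitudes within the part of the state where wire 0 is |1⟩, so ⟨Z₀⟩ = cos²(θ/2) − sin²(θ/2) = cos θ. The sketch below checks this with a hand-written state vector in plain JAX. `expval_z0` is a hypothetical helper written for this check, not a PennyLane function; because it is ordinary JAX code, `jax.vmap` batches it directly without any quantum device.

```python
import jax
import jax.numpy as jnp

# Use float64, matching the jax_enable_x64 setting in the script above.
jax.config.update("jax_enable_x64", True)

def expval_z0(theta):
    """Hypothetical pure-JAX model of RX(theta) on wire 0, then CNOT(0, 1), measuring Z on wire 0."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    # Basis order |00>, |01>, |10>, |11> (wire 0 is the leftmost bit).
    # RX(theta)|0> = cos(theta/2)|0> - i sin(theta/2)|1>; wire 1 stays in |0>.
    state = jnp.array([c, 0.0, -1j * s, 0.0])
    # CNOT(0, 1) swaps the amplitudes of |10> and |11>.
    state = state[jnp.array([0, 1, 3, 2])]
    # PauliZ on wire 0 has eigenvalue +1 on |0x> and -1 on |1x>.
    z0 = jnp.array([1.0, 1.0, -1.0, -1.0])
    return jnp.real(jnp.sum(z0 * jnp.abs(state) ** 2))

batch_params = jnp.array([1.02, 0.123, -0.571])
batched = jax.vmap(expval_z0)(batch_params)
print(batched)
print(jnp.allclose(batched, jnp.cos(batch_params)))  # True
```

Up to the printed precision, this agrees with the batched result from the compiled circuit above.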
I can’t figure out why this works. Any insight would be helpful!
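Some background on what the batching step has to do: `jax.vmap(f)(xs)` is defined to return the same values as calling `f` on each slice of `xs` along the leading axis and stacking the results. Whether those evaluations actually run in parallel depends on how `f` (here, the compiled circuit) handles the batched input. A minimal sketch, using a hypothetical stand-in `cost` in place of the QNode so it runs without a quantum device:

```python
import jax
import jax.numpy as jnp

def cost(theta):
    # Hypothetical stand-in for the QNode: any JAX-traceable scalar function.
    return jnp.cos(theta)

batch_params = jnp.array([1.02, 0.123, -0.571])

# jax.vmap maps `cost` over the leading axis of its input in a single call.
vmapped = jax.vmap(cost)(batch_params)

# Reference semantics: call `cost` once per element and stack the results.
looped = jnp.stack([cost(p) for p in batch_params])

print(jnp.allclose(vmapped, looped))  # True
```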