Hello everyone, I have a PennyLane quantum neural network that takes 2 values as input and outputs 1 value. I now have a dataset of shape N × 2. I want to apply this quantum neural network to all N rows at once so that it returns an array of N elements. How can I vectorize my circuit to do that?
Essentially, how can I apply my circuit to each row of an array without using loops, maps, or list comprehensions? All of these become very slow as the number of samples grows.
Hi @Smayan_Gupta, welcome to the Forum!
You might want to follow the example in our demo on using JAX with PennyLane.
I’ve updated the code from that demo to show how you can create a version of the circuit that is vectorized over the inputs instead of over the parameters.
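The key tool is `jax.vmap`: it turns a function written for a single sample into one that runs over a whole batch, and its `in_axes` argument says which arguments carry the batch axis. Here is a minimal sketch of that idea with no quantum code involved; `model` is a hypothetical classical stand-in with the same `(inputs, params)` signature as the QNode, and the data is made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical classical stand-in for the QNode: same (inputs, params)
# signature, one scalar output per sample.
def model(inputs, params):
    return jnp.cos(inputs[0]) * jnp.cos(params[0]) + jnp.sin(inputs[1]) * params[1]

X = jnp.arange(6.0).reshape(3, 2)   # N = 3 samples with 2 features each
params = jnp.array([0.5, -1.0])     # one shared parameter vector

# Map over axis 0 of `inputs`; pass the same `params` to every sample.
over_inputs = jax.vmap(model, in_axes=(0, None))(X, params)  # shape (3,)

# The opposite choice: one sample, several parameter vectors.
param_sets = jnp.array([[0.5, -1.0], [1.0, 2.0]])
over_params = jax.vmap(model, in_axes=(None, 0))(X[0], param_sets)  # shape (2,)

# Reference result with an explicit Python loop (what vmap replaces).
loop_result = jnp.array([model(x, params) for x in X])
print(over_inputs, over_params, loop_result)
```

`in_axes=(0, None)` is the "vectorize over the inputs" case you need; `in_axes=(None, 0)` is the "vectorize over the parameters" case.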
I hope this helps!
```python
import pennylane as qml
import jax
import jax.numpy as jnp

# QNode
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(inputs, params):
    # Encode inputs
    qml.RX(inputs[0], wires=0)
    qml.RX(inputs[1], wires=1)
    # Ansatz
    qml.RY(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

inputs = jnp.array([1.02, 0.123])
params = jnp.array([2.02, 2.123])
print(f"Result: {repr(circuit(inputs, params))}")

# Create a vectorized version of our original circuit.
# in_axes indicates which array axis to map over for each argument.
# In this case we map over axis 0 of the first argument (inputs) and
# do not map over the `params` argument (None), so every sample
# uses the same parameters.
vcircuit = jax.vmap(circuit, in_axes=(0, None))

# Now we call `vcircuit` with multiple inputs at once and get back a
# batch of expectation values, one per row of `batch_inputs`.
# This example evaluates the circuit on 3 inputs in a single vectorized call.
batch_inputs = jnp.array([[1.02, 0.123], [1.02, 0.123], [1.02, 0.123]])
batched_results = vcircuit(batch_inputs, params)
print(f"Batched result: {batched_results}")
```