QML training using shot-based simulator

I’d like to run QML training using shot-based simulation, using:
dev = qml.device("default.qubit", wires=n_wire, shots=shots)
Based on your Iris tutorial, I made an abbreviated, no-graphs version of the binary Iris classifier, which works with the state-vector simulator and NesterovMomentumOptimizer:
tut05_nestorovOpt_irisBinary_stateVect.py

I understand that for shot-based simulation with PennyLane I need to change the QNode output to: return qml.probs(wires=[0, 1])
And that Nesterov is a gradient-based optimizer, so I either need to provide the gradient using the parameter-shift rule or switch to an optimizer that does not require gradients, say from JAXopt. But the devil is in the details.

Can you help me modify my code?

Hey @Jan_Balewski,

Can you provide some code here that shows what you’re trying to accomplish? I’m not sure I understand exactly what you’re trying to do. Thanks!

Sorry for not being clearer. This is the code I have trouble with:

It runs with the state-vector device and gives me 72% classification accuracy for 2D input data that are not linearly separable:
dev = qml.device('default.qubit', wires=n_qubits) # works
But I want to run it with the shot-based simulator:
dev = qml.device('default.qubit', wires=n_qubits, shots=5000)
And I see the error:
Computing the gradient of broadcasted tapes with the parameter-shift rule gradient transform is currently not supported.
I’m hoping for advice on how to run PennyLane optimally for such a device. I left the code in the ‘crashing’ state.
Here are the training & validation data: quantumMind/PennyLane/notebooks/data/circ2d_bin.npy at main · balewski/quantumMind · GitHub

Hey @Jan_Balewski, thanks! I think it would be best for you to reduce your code down to something much simpler so that we can better understand where the error is stemming from. This is a good debugging exercise, too!

Can you create a small, toy example problem that uses the same optimizer? The minimal example doesn’t even have to be something that makes sense; it could have dummy data being fed into a dummy variational circuit that you try to optimize once with finite shots and without shots.

Functionally stripping your code base down like this is a really good debugging skill to develop :slight_smile:. Let me know where you get to and we’ll try to lend a hand!

Right, I should have removed the bells and whistles. The code below makes some progress toward separating two non-linearly separable classes (the cost goes down) if I use
dev = qml.device('default.qubit', wires=n_qubits)
This is the circuit and training over 10 epochs:

0: ──RY(1.47)──||──RX(0.11)──RY(0.02)──RZ(0.10)─╭●────╭X─┤     
1: ──RY(2.74)──||──RX(0.09)──RY(0.17)──RZ(0.12)─╰X─╭●─│──┤     
2: ────────────||──RX(0.13)──RY(0.16)──RZ(0.11)────╰X─╰●─┤  <Z> 

epoch=0, cost=0.617
epoch=1, cost=0.614
epoch=2, cost=0.610
epoch=3, cost=0.606
epoch=4, cost=0.603
epoch=5, cost=0.601
epoch=6, cost=0.600
epoch=7, cost=0.600
epoch=8, cost=0.600
epoch=9, cost=0.600

But if I change the device to be shot-based, training crashes. Below is the ‘crashing’ version:

import numpy as cnp
import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import NesterovMomentumOptimizer

n_sampl = 1000  ; n_feature=2; n_qubits=3; layers=1; epochs=10

#dev = qml.device('default.qubit', wires=n_qubits)  # works
dev = qml.device('default.qubit', wires=n_qubits, shots=5000) #  CRASHING

#.... input data
X = cnp.random.uniform(-1, 1, size=(n_sampl, n_feature))
Y = cnp.where(X[:, 0] * X[:, 1] > 0, 1, -1) # Compute labels
#... trainable params
params = 0.2 * np.random.random(size=(layers, n_qubits,3))

@qml.qnode(dev)  
def circuit(params,x):
    a = np.arccos(x)  # encode the input into angles in [0, pi]
    qml.RY(a[0], wires=0)
    qml.RY(a[1], wires=1)
    for layer in range(layers): # EfficientSU2 ansatz
        qml.Barrier()
        for qubit in range(n_qubits):
            qml.RX(params[layer, qubit, 0], wires=qubit)
            qml.RY(params[layer, qubit, 1], wires=qubit)
            qml.RZ(params[layer, qubit, 2], wires=qubit)        
        for qubit in range(n_qubits):
            qml.CNOT(wires=[qubit, (qubit + 1) % n_qubits])
    return qml.expval(qml.PauliZ(2))

print(qml.draw(circuit, decimals=2)(params,X[0]), '\n')

#... classical ML utility func
def cost_function( params, X, Y):  # vectorized code
    pred = circuit(params,  X.T)
    cost = np.mean((Y - qml.math.stack(pred)) ** 2)
    return cost

#... run optimizer
opt = NesterovMomentumOptimizer(0.10, momentum=0.90)
for it in range(epochs):
     params = opt.step(lambda p: cost_function(p, X, Y), params)
     cost= cost_function(params, X, Y)
     print('epoch=%d, cost=%.3f'%(it,cost))

The error is:

File "/usr/local/lib/python3.10/dist-packages/pennylane/gradients/gradient_transform.py", line 111, in assert_no_tape_batching
NotImplementedError: Computing the gradient of broadcasted tapes with the parameter-shift rule gradient transform is currently not supported. See #4462 for details.

Thanks @Jan_Balewski!

This is a known bug ([BUG] Differentiation of broadcasted QNode via gradient transforms · Issue #4462 · PennyLaneAI/pennylane · GitHub) but I think this might be showing another consequence of this bug :thinking:. I’ll see what the status is here!

I did not know how to find the URL for the ticket based only on its number.
I tried a different diff_method, but this combination also fails:

dev = qml.device('default.qubit', wires=n_qubits,shots=5000)
@qml.qnode(dev,diff_method="backprop")

The new error is:

File "/usr/local/lib/python3.10/dist-packages/pennylane/workflow/qnode.py", line 742, in _validate_backprop_method
    raise qml.QuantumFunctionError("Backpropagation is only supported when shots=None.")
pennylane.QuantumFunctionError: Backpropagation is only supported when shots=None.

I think when shots are used you need to fire more circuits with slightly varied params and compute a gradient estimate. There is no magic.
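For concreteness, a toy sketch of that idea for a single rotation angle (shift_grad is my own illustrative helper, not a PennyLane API):

import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=1, shots=5000)

@qml.qnode(dev)
def f(theta):
    qml.RY(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

def shift_grad(theta, s=np.pi / 2):
    # two extra shot-based circuit executions per parameter,
    # combined into a parameter-shift gradient estimate
    return (f(theta + s) - f(theta - s)) / (2 * np.sin(s))

print(shift_grad(0.7), -np.sin(0.7))  # shot-noise estimate vs exact derivative of cos(theta)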

Yep! I’m looking into this :slight_smile:. The issue I linked was “resolved”, but there might still be a problem. I’m chatting with some devs here and will get back to you as soon as I can.

Is there any solution to this issue in sight?

Hey @Jan_Balewski,

Apologies for the wait! I’ll try and get an update for you today.

I’ve managed to make some progress using opt=qml.SPSAOptimizer(maxiter=epochs), which does work with the shot-based simulator. However, it appears to be less efficient than the Nesterov or Adam optimizers, which would be my preferred choices.
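For reference, the swap is minimal (assuming the cost_function, params, X, Y from my earlier script):

# SPSA is gradient-free: it only needs (noisy) cost evaluations,
# so it runs on the shot-based device without a gradient transform
opt = qml.SPSAOptimizer(maxiter=epochs)
for it in range(epochs):
    params = opt.step(lambda p: cost_function(p, X, Y), params)
    print('epoch=%d, cost=%.3f' % (it, cost_function(params, X, Y)))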

Hey @Jan_Balewski,

Our dev team is a little busy at the moment, and tomorrow there’s a holiday in Canada. So, looking like we can get back to you early next week :slight_smile:. Rest assured that it’s on our radar and that we’re looking into it!

Hi @Jan_Balewski,

I think I understand the problem above: in your setup, you combined a broadcasted (vectorized over X) QNode with an unspecified differentiation method. This means that PennyLane will choose the best available diff_method on its own. For shots=None it picks "backprop", as that is usually the most suitable method in this scenario. For shots!=None there is no backprop support, so PennyLane picks "parameter-shift" instead.

Unfortunately, "parameter-shift" (and "finite-diff" and "hadamard" as well, for that matter) does not currently support broadcasted QNodes, even if the broadcasting is not in the parameters with respect to which we compute the derivatives. The result is that your code runs smoothly with shots=None (i.e. diff_method="backprop") but not with shots!=None (i.e. diff_method="parameter-shift").

So what could you do to make this work?

  1. Skip the broadcasting. It will probably be (much) slower, but it will run as expected. Instead of feeding the vectorized X into the QNode, use a for loop over the entries in X (see the sketch after this list).
  2. Use a different optimizer, as you already mentioned yourself. However, this may not yield the desired convergence quality, as you also found.
  3. There might be a way to make your code work by vectorizing it with, e.g., jax.vmap (or the corresponding techniques built into PyTorch/TensorFlow). This requires using jax for the autodifferentiation in the first place, and I believe you would have to switch to a jax implementation of your optimizer of choice (standard optimizers like Adam are of course available). Would this be an option?
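For option 1, a minimal sketch of the looped cost function, using the names from your script:

def cost_function(params, X, Y):
    # one non-broadcasted QNode execution per sample: slower, but
    # compatible with the parameter-shift rule on the shot-based device
    preds = [circuit(params, x) for x in X]
    return np.mean((Y - qml.math.stack(preds)) ** 2)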

Ideally, we’d support broadcasted QNodes with all these differentiation methods. Possibly we can make this work for restricted scenarios like yours, where the broadcasting does not come from differentiated parameters. I will look into it, but I cannot promise a fast patch for this, unfortunately.

I hope you find a way around this incompatibility.

Happy coding!

Follow-up: I drafted a PR that should allow your code to run: #5452.
You can clone the repository, check out this PR’s branch, and install PennyLane from that directory to try this out now, or wait for the release that will contain this patch :slight_smile:

As far as I can tell, your code will run successfully with this patch, as long as you replace the np.arccos call with cnp.arccos. As a rule of thumb, it’s a good idea to move classical processing outside of QNodes to avoid problems with trainability (in your code, np.arccos turns the vanilla numpy array into a trainable autograd numpy array, causing some funky issues :upside_down_face:).

Hope this helps! Happy coding :slight_smile:

Thank you guys for analyzing my code and identifying the limitations. Although I prefer vectorized numpy code because it performs better at scale, in this case I’ll step back and use simpler scalar code.

Follow-up: I tested it and the results are not good. I applied both suggested changes to the example above and reduced the number of samples from 1k to 300 to avoid long waits. Below is a comparison of per-step execution time for 4 combinations of device type and circuit execution method:

dev: state-based, circuit execution: vectorized --> 0.06 sec/step
dev: state-based, circuit execution: list comprehension --> 3.6 sec/step (60x slower)
dev: shot-based, circuit execution: list comprehension --> 18 sec/step (300x slower)
dev: shot-based, circuit execution: vectorized --> crash, as expected

Giving up vectorized circuit execution does not feel like a practical solution; my computing resources are finite. The key issue is the 60x slowdown; I could live with the additional 5x cost of the shot-based simulator relative to the state-vector simulator.
Below is the code as used:

import numpy as cnp
import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import NesterovMomentumOptimizer
from time import time

n_sampl = 300  ; n_feature=2; n_qubits=3; layers=1; epochs=10

#dev = qml.device('default.qubit', wires=n_qubits)  # works
dev = qml.device('default.qubit', wires=n_qubits, shots=1000) 

#.... input data
X = cnp.random.uniform(-1, 1, size=(n_sampl, n_feature))
Y = cnp.where(X[:, 0] * X[:, 1] > 0, 1, -1) # Compute labels
#... trainable params
params = 0.2 * np.random.random(size=(layers, n_qubits,3))

@qml.qnode(dev)  
def circuit(params,x):
    a = cnp.arccos(x)  # encode the input into angles in [0, pi]
    qml.RY(a[0], wires=0)
    qml.RY(a[1], wires=1)
    for layer in range(layers): # EfficientSU2 ansatz
        qml.Barrier()
        for qubit in range(n_qubits):
            qml.RX(params[layer, qubit, 0], wires=qubit)
            qml.RY(params[layer, qubit, 1], wires=qubit)
            qml.RZ(params[layer, qubit, 2], wires=qubit)        
        for qubit in range(n_qubits):
            qml.CNOT(wires=[qubit, (qubit + 1) % n_qubits])
    return qml.expval(qml.PauliZ(2))

print(qml.draw(circuit, decimals=2)(params,X[0]), '\n')

#... classical ML utility func
def cost_function( params, X, Y):  # pick one execution mode below
    predL= [ circuit(params, x1) for x1 in X ]  # list comprehension
    #predL = circuit(params,  X.T)  # vectorized execution
    cost = np.mean((Y - qml.math.stack(predL)) ** 2)
    return cost

#... run optimizer
opt = NesterovMomentumOptimizer(0.10, momentum=0.90)
for it in range(epochs):
    T0=time()
    params = opt.step(lambda p: cost_function(p, X, Y), params)
    durT=time()-T0
    print('epoch=%d,   %.2f sec/step'%(it,durT))

My software stack:

pip3 list |grep Pen
PennyLane                       0.35.1
PennyLane-Cirq                  0.34.0
PennyLane_Lightning             0.35.1
PennyLane_Lightning_GPU         0.35.1
PennyLane-qiskit                0.35.1
PennyLane-SF                    0.29.0

Hi @Jan_Balewski
Just to make sure, did you see my second reply? Does it help you?

Would using JAX be an option? It is supported throughout PennyLane (except for optimizers, as mentioned, but those are readily available in packages like optax or jaxopt).

Since I have very little experience with optax or jaxopt, I’m not sure. In general I’m open to using them for vectorized circuit execution. I realize I’m calling a PennyLane helpline here and you have other things to do. Would you have time to guide me on how to change the code posted here to adopt them (to prove the speed is good), or is that homework for me?
But my ultimate goal is to run this type of optimization problem on much larger circuits, which will have many (20-40) feed-forward operations (i.e. mid-circuit measurements and gates conditioned on those measurements); the circuit size will be ~20 qubits and ~100 entangling gates, and the device will be real IBM hardware (so a transpiler will further transform the initial circuit, and results will be non-continuous and shot-based). If I change to optax or jaxopt I want to be sure I can still do what I’m planning.
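(For concreteness, a toy sketch of the kind of feed-forward step I mean, written with PennyLane’s mid-circuit measurement API; the circuit itself is just a dummy example:)

import pennylane as qml

dev = qml.device("default.qubit", wires=2, shots=1000)

@qml.qnode(dev)
def feed_forward_toy():
    qml.Hadamard(wires=0)
    m = qml.measure(0)                # mid-circuit measurement on wire 0
    qml.cond(m, qml.PauliX)(wires=1)  # gate conditioned on the measured outcome
    return qml.expval(qml.PauliZ(1))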

It’s hard to anticipate whether the whole workflow can be ported over to hardware etc., but I don’t see a categorical reason why it should be impossible.
However, as broadcasting/batching will not really work on quantum hardware anyway, maybe it is not worth the work to switch? That’s ultimately your decision. :slight_smile:

If you do want to switch to JAX, PennyLane supports it properly; you will mostly just have to switch to an optimizer from the mentioned packages, so there is not too much to it. If you use qml.math instead of cnp or np, PennyLane will apply the correct classical processing functions for you :slight_smile:
Also, the interface will be recognized automatically by the QNode, so you don’t have to worry about the QNode code too much either.
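For example (a tiny illustration, using arccos since it appears in your script; qml.math should forward the call to whichever framework the input belongs to):

import numpy as cnp
import jax.numpy as jnp
import pennylane as qml

x_np = cnp.array([0.1, 0.5])
x_jx = jnp.array([0.1, 0.5])

# qml.math dispatches to the framework of its input
print(qml.math.arccos(x_np))  # returns a plain numpy array
print(qml.math.arccos(x_jx))  # returns a jax array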
There is also a very recent how-to guide for PennyLane + JAX + JAXopt here, and one for PennyLane + JAX + Optax here.
I hope this helps, in case you do want to switch to JAX.

As an alternative, there is the patch I mentioned here, which should allow you to do exactly what you were doing in the first place, with a single line of code changed.

Thanks for pointing me to this JAX+Optax tutorial. I have modified my code to use both.
It works with dev = qml.device('default.qubit', wires=n_qubits)
but crashes with dev = qml.device('default.qubit', wires=n_qubits, shots=1000)
The error is ValueError: probabilities do not sum to 1. Below is the full dump:

M: verify code sanity, X: (300, 2)
Traceback (most recent call last):
  File "/PennyLane/toys/./toy_opt_speed_jax_optax.py", line 62, in <module>
    val=loss_fn(params, X,Y)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 143, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2727, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 423, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1409, in _pjit_call_impl
    return xc._xla.pjit(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1392, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1348, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1201, in __call__
    results = self.xla_executable.execute_sharded(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: probabilities do not sum to 1

At:
  numpy/random/_generator.pyx(828): numpy.random._generator.Generator.choice
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(443): <listcomp>
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(443): sample_state
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(297): _measure_with_samples_diagonalizing_gates
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(198): measure_with_samples
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/simulate.py(200): measure_final_state
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/simulate.py(263): simulate
  /usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py(554): <genexpr>

Attached is the reproducer:

import numpy as cnp
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
from time import time

n_sampl = 300
n_feature=2; n_qubits=3; layers=1; steps=10

#dev = qml.device('default.qubit', wires=n_qubits)  # works
dev = qml.device('default.qubit', wires=n_qubits, shots=1000)  # crashes 

#.... input data
Xu= cnp.random.uniform(-1, 1, size=(n_sampl, n_feature) )
Xa=cnp.arccos(Xu)
X =  jnp.array(Xa )
Y = jnp.where(Xu[:, 0] * Xu[:, 1] > 0, 1, -1) # Compute labels
#... trainable params
params = 0.2 * jnp.array( cnp.random.random(size=(layers, n_qubits,3)) )

@qml.qnode(dev)  
def circuit(params,x):
    qml.RY(x[0], wires=0)
    qml.RY(x[1], wires=1)
    for layer in range(layers): # EfficientSU2 ansatz
        qml.Barrier()
        for qubit in range(n_qubits):
            qml.RX(params[layer, qubit, 0], wires=qubit)
            qml.RY(params[layer, qubit, 1], wires=qubit)
            qml.RZ(params[layer, qubit, 2], wires=qubit)        
        for qubit in range(n_qubits):
            qml.CNOT(wires=[qubit, (qubit + 1) % n_qubits])
    return qml.expval(qml.PauliZ(2))

print(qml.draw(circuit, decimals=2)(params,X[0]), '\n')

#... classical ML utility func
@jax.jit
def loss_fn( params, X, Y):  # vectorized code
    pred = circuit(params,  X.T)  # vectorized execution
    cost = jnp.mean((Y - pred) ** 2)
    return cost

print('M: verify code sanity, X:',X.shape)
T0=time()
val=loss_fn(params, X,Y)  # <=== CRASH IS HERE
durT=time()-T0
print('elaT=%.1f sec, one loss_fn:%s'%(durT,val))

print('grad:',jax.grad(loss_fn)(params,X,Y))

#... run optimizer
opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)

def update_step(opt, params, opt_state, data, targets):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

for i in range(100):
    params, opt_state, loss_val = update_step(opt, params, opt_state, X,Y)
 
    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

How shall I proceed now to run optimization on a shot-based device?