Error using Operators + Jax + Shots

Dear Pennylane team,
I want to report an error (a potential bug).

I want to train my quantum circuit using jax AND shots.
While everything works perfectly without shots, the following error appears when turning them on:

JaxStackTraceBeforeTransformation: TypeError: cannot reshape array of shape () (size 1) into shape (6,) (size 6)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[73], line 3
      1 weights_init = jax.random.uniform(key, (2,3),  minval=0.0, maxval=jnp.pi / 2)
      2 wights_2 = jnp.array([0.1])
----> 3 optimization_jit(weights_init, True)

    [... skipping hidden 14 frame]

Cell In[72], line 40, in optimization_jit(angles_one, print_training)
     37 loss_history = jnp.zeros(training_iterations_adam)
     39 args = (angles_one, opt_state, print_training, loss_history)
---> 40 angles_one, _, _,loss_history = jax.lax.fori_loop(0, training_iterations_adam, update_step_jit, args)
     42 return angles_one, loss_history

    [... skipping hidden 13 frame]

Cell In[72], line 16, in update_step_jit(i, args)
     13 angles_one, opt_state, print_training, loss_history = args
     15 # Compute loss and gradients w.r.t. angle_vector
---> 16 loss_val, grads = jax.value_and_grad(cost_overlap, argnums=0)(angles_one)
     18 # Apply optimization step
     19 updates, opt_state = opt1.update(grads, opt_state)

    [... skipping hidden 65 frame]

File ~\Desktop\TrainingProject\jax-env\lib\site-packages\pennylane\workflow\interfaces\jax_jit.py:210, in _execute_and_compute_jvp(tapes, execute_fn, jpc, device, primals, tangents)
    207 jac_struct = tuple(_jac_shape_dtype_struct(t, device) for t in tapes.vals)
    208 results, jacobians = jax.pure_callback(wrapper, (res_struct, jac_struct), primals[0])
--> 210 jvps = _compute_jvps(jacobians, tangents_trainable, tapes.vals)
    212 return results, jvps

File ~\Desktop\TrainingProject\jax-env\lib\site-packages\pennylane\workflow\jacobian_products.py:72, in _compute_jvps(jacs, tangents, tapes)
     70         jvps.append(tuple(f[multi](dx, j) for j in jac))
     71     else:
---> 72         jvps.append(f[multi](dx, jac))
     73 return tuple(jvps)

File ~\Desktop\TrainingProject\jax-env\lib\site-packages\pennylane\gradients\jvp.py:179, in compute_jvp_single(tangent, jac)
    177     new_shape = shape[: len(shape) - first_tangent_ndim] + (tangent_size,)
    178     jac = qml.math.cast(qml.math.convert_like(jac, tangent), tangent.dtype)
--> 179     jac = qml.math.reshape(jac, new_shape)
    180     return qml.math.tensordot(jac, tangent, [[-1], [0]])
    182 tangent_ndims = [getattr(t, "ndim", 0) for t in tangent]

File ~\Desktop\TrainingProject\jax-env\lib\site-packages\autoray\autoray.py:81, in do(fn, like, *args, **kwargs)
     79 backend = _choose_backend(fn, args, kwargs, like=like)
     80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)

    [... skipping hidden 3 frame]

File ~\Desktop\TrainingProject\jax-env\lib\site-packages\jax\_src\numpy\array_methods.py:470, in _compute_newshape(arr, newshape)
    467 else:
    468   if (all(isinstance(d, int) for d in (*arr.shape, *newshape)) and
    469       arr.size != math.prod(newshape)):
--> 470     raise TypeError(f"cannot reshape array of shape {arr.shape} (size {arr.size}) "
    471                     f"into shape {orig_newshape} (size {math.prod(newshape)})")
    472 return tuple(-core.divide_shape_sizes(arr.shape, newshape)
    473              if core.definitely_equal(d, -1) else d for d in newshape)

TypeError: cannot reshape array of shape () (size 1) into shape (6,) (size 6)

The problem seems to be caused by using a self-defined Operation with trainable gates (using the same Operation with fixed gates does not cause the error).

Here is a minimal example that reproduces the error for me:

import jax
import jax.numpy as jnp
import optax
import pennylane as qml
from pennylane.operation import AnyWires, Operation


class VariationalNetwork(Operation):
    num_wires = AnyWires
    grad_method = "A" # parameter shift differentiation

    def __init__(self, angle_array, wires, depth=None, id=None):

        # depth is not trainable but influences the action of the operator,
        # which is why we define it to be a hyperparameter
        self._hyperparameters = {
            "depth": depth
        }

        # The parent class expects all trainable parameters to be fed as positional
        # arguments, and all wires acted on fed as a keyword argument.
        # The id keyword argument allows users to give their instance a custom name.
        super().__init__(angle_array, wires=wires, id=id) # calls the init of the operator class

    @staticmethod
    def compute_decomposition(angle_array, wires, depth):
        op_list = []

        # apply trainable RY rotations, alternating the starting wire each layer
        for d in range(depth):
            start = 0 if d % 2 == 0 else 1  # Toggle starting index
            for w in range(start, len(wires) - 1, 2):
                op_list.append(qml.RY(jnp.array(angle_array[0,d]), wires=[wires[w]]))
                op_list.append(qml.RY(jnp.array(angle_array[1,d]), wires=[wires[w]]))
        return op_list

key = jax.random.PRNGKey(2)

@jax.jit
def train_one_step(angles_one, key=key):
    
    dev = qml.device("default.qubit", wires=3, shots=100) #, seed=key,seed=key, seed=key,  seed=key,  
    
    @qml.qnode(dev, interface="jax")
    def circuit(angles_one):
        VariationalNetwork(angles_one, wires=[0,1], depth=3)
        return qml.expval(qml.PauliZ(0)) 

    #fig, ax = qml.draw_mpl(circuit)(angles_one) #, level = "device"
    #plt.show()
    return circuit(angles_one)

lr = 0.001 #0.0362      # learning rate adam
training_iterations_adam = 100
opt1 = optax.adam(learning_rate=lr)


@jax.jit
def cost_overlap(angles_one):
    return 1 - train_one_step(angles_one)

def update_step_jit(i, args):
    angles_one, opt_state, print_training, loss_history = args

    # Compute loss and gradients w.r.t. angle_vector
    loss_val, grads = jax.value_and_grad(cost_overlap, argnums=0)(angles_one)

    # Apply optimization step
    updates, opt_state = opt1.update(grads, opt_state)
    angles_one = optax.apply_updates(angles_one, updates)
    loss_history = loss_history.at[i].set(loss_val)  # Store the loss

    # Debug printing every 5 steps if print_training is True
    def print_fn(_):
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

        #jax.debug.print("Step: {i}  angles old: {angles_old}", i=i, angles_old=angles_old)
        #jax.debug.print("Step: {i}  angles_new: {angles_new}", i=i, angles_new=angles_new)

    jax.lax.cond((i % 10 == 0) & print_training, print_fn, lambda _: None, operand=None)

    return (angles_one, opt_state, print_training, loss_history)

@jax.jit
def optimization_jit(angles_one, print_training=False):
    opt_state = opt1.init(angles_one)
    loss_history = jnp.zeros(training_iterations_adam)

    args = (angles_one, opt_state, print_training, loss_history)
    angles_one, _, _, loss_history = jax.lax.fori_loop(0, training_iterations_adam, update_step_jit, args)

    return angles_one, loss_history

weights_init = jax.random.uniform(key, (2,3),  minval=0.0, maxval=jnp.pi / 2)
wights_2 = jnp.array([0.1])
optimization_jit(weights_init, True)

Note: The error only appears for shots != None. It also works if I only call the cost function; the error is only raised inside the JAX training routine.
It can be avoided by not defining an Operation. I can do that, but it gets a bit annoying when plotting large quantum circuits.
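For context, "not defining an Operation" just means applying the same gates inline in the QNode instead of through the custom class, roughly like this (a sketch using the same names as above):

dev = qml.device("default.qubit", wires=3, shots=100)

@qml.qnode(dev, interface="jax")
def circuit_inline(angles_one):
    # Same gate structure as VariationalNetwork.compute_decomposition,
    # applied directly so that no custom Operation is involved.
    depth, wires = 3, [0, 1]
    for d in range(depth):
        start = 0 if d % 2 == 0 else 1
        for w in range(start, len(wires) - 1, 2):
            qml.RY(angles_one[0, d], wires=wires[w])
            qml.RY(angles_one[1, d], wires=wires[w])
    return qml.expval(qml.PauliZ(0))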

In general it feels like using shots slows the code down a lot. Is there any way to mitigate that?

Best regards,
Pia

Hi @Pia ,

Using JAX with shots and samples requires changing a couple of things in the code.

You can find these guidelines in the Randomness section of the docs for the JAX interface.

I would recommend testing this on a small example first, to see if these guidelines solve the problem, before moving on to a more complicated example like the one you shared.
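For reference, a minimal sketch of that pattern could look like the following (assuming your PennyLane version lets default.qubit take a jax.random.PRNGKey as its seed, as described in those guidelines):

import jax
import jax.numpy as jnp
import pennylane as qml

key = jax.random.PRNGKey(0)
# Passing a JAX PRNG key as the device seed keeps shot-based sampling
# compatible with jax.jit (see the Randomness guidelines for details).
dev = qml.device("default.qubit", wires=2, shots=100, seed=key)

@qml.qnode(dev, interface="jax")
def circuit(theta):
    qml.RX(theta, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

print(jax.jit(circuit)(jnp.array(0.3)))

If a small case like this works, you can add your custom Operation and the optimization loop back in one piece at a time to see where it breaks.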

I hope this helps!

Dear Catalina,
Thanks for the quick reply.
I had already read this tutorial and implemented very simple circuits using default.qubit + seed; that worked fine.

However, as soon as I switch to lightning.qubit, I get an error if I use the seed with the jax-key, and I can’t find any specific tutorial/question/documentation on how to use these together.
Further, I don’t want to use samples (if possible) but rather compute the expectation value directly, just with shots (or is that a problem)?

My problems are specifically:

  1. Using JAX + Operations + shots together raises an error.
  2. Using JAX + lightning.qubit + shots + a JAX seed raises an error.
  3. Everything that works gets very slow as soon as I switch to shots != None.

Are there any more specific resources for these problems?
Best regards,
Pia

Hi @Pia ,

Thanks for the additional details.

I think there may be several things happening here.

  1. Jitting + conditional gates is probably not a great idea, since jitting needs some things to stay static. This is just a hypothesis though; we’d need to test it to confirm.
  2. lightning.qubit has different default behaviour and internals under the hood compared to default.qubit. How many qubits are you expecting to need? If it’s fewer than 15, it’s probably best to keep using default.qubit.
  3. Do you need jitting? You can use JAX without it, which gives you some flexibility and can potentially avoid problems.
  4. Using expectation values with shots should not be the problem.
  5. Everything getting slow when using shots might indicate that a different process is happening under the hood by default, for example a slower differentiation method. You can try setting diff_method="adjoint" to see if this helps, although it may not be compatible with all of the moving pieces you have here.
  6. Finally, I noticed that the origin of the issue has to do with the Jacobian. I wonder if setting max_diff=2 as a QNode argument could help; see the sketch after this list.
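As a rough sketch of points 5 and 6, using your own names (the exact settings are something you’d need to test; adjoint differentiation generally requires analytic mode, i.e. shots=None, so with shots you may have to fall back to parameter-shift):

dev = qml.device("default.qubit", wires=3, shots=100)

# Hypothetical settings to experiment with: an explicit diff_method and max_diff=2.
@qml.qnode(dev, interface="jax", diff_method="parameter-shift", max_diff=2)
def circuit(angles_one):
    VariationalNetwork(angles_one, wires=[0, 1], depth=3)
    return qml.expval(qml.PauliZ(0))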

I know this is a lot to try. If this doesn’t work, please let me know what your end goal is and what is “nice to have”. I know that the suggestions to use JAX, jitting, lightning and adjoint are general recommendations to get speedups, but they’re not always the right tool for the task. With more context on what you’re looking to do, maybe we can figure out a minimal example that works for what you need!

I hope this helps.

Hi,
thanks for the detailed response.
I have now circumvented the problems by implementing my own version of sampling and noisy gradients. That way it runs much faster for me, and I can keep my Operations.
Best,
Pia

Thanks for confirming, Pia.

If you’re okay with sharing your implementation, feel free to post it here. It might help others with similar problems, and it might help us find alternative workflows that could eventually be integrated into PennyLane.

If you want to keep your code private that’s totally ok too.

In any case I’m glad that you were able to circumvent your problems!