Dear Pennylane team,
I want to report an error (a potential bug report).
I want to train my quantum circuit using jax AND shots.
While everything works perfectly without shots, the following error appears when I turn them on:
JaxStackTraceBeforeTransformation: TypeError: cannot reshape array of shape () (size 1) into shape (6,) (size 6)
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[73], line 3
1 weights_init = jax.random.uniform(key, (2,3), minval=0.0, maxval=jnp.pi / 2)
2 wights_2 = jnp.array([0.1])
----> 3 optimization_jit(weights_init, True)
[... skipping hidden 14 frame]
Cell In[72], line 40, in optimization_jit(angles_one, print_training)
37 loss_history = jnp.zeros(training_iterations_adam)
39 args = (angles_one, opt_state, print_training, loss_history)
---> 40 angles_one, _, _,loss_history = jax.lax.fori_loop(0, training_iterations_adam, update_step_jit, args)
42 return angles_one, loss_history
[... skipping hidden 13 frame]
Cell In[72], line 16, in update_step_jit(i, args)
13 angles_one, opt_state, print_training, loss_history = args
15 # Compute loss and gradients w.r.t. angle_vector
---> 16 loss_val, grads = jax.value_and_grad(cost_overlap, argnums=0)(angles_one)
18 # Apply optimization step
19 updates, opt_state = opt1.update(grads, opt_state)
[... skipping hidden 65 frame]
File ~\Desktop\TrainingProject\jax-env\lib\site-packages\pennylane\workflow\interfaces\jax_jit.py:210, in _execute_and_compute_jvp(tapes, execute_fn, jpc, device, primals, tangents)
207 jac_struct = tuple(_jac_shape_dtype_struct(t, device) for t in tapes.vals)
208 results, jacobians = jax.pure_callback(wrapper, (res_struct, jac_struct), primals[0])
--> 210 jvps = _compute_jvps(jacobians, tangents_trainable, tapes.vals)
212 return results, jvps
File ~\Desktop\TrainingProject\jax-env\lib\site-packages\pennylane\workflow\jacobian_products.py:72, in _compute_jvps(jacs, tangents, tapes)
70 jvps.append(tuple(f[multi](dx, j) for j in jac))
71 else:
---> 72 jvps.append(f[multi](dx, jac))
73 return tuple(jvps)
File ~\Desktop\TrainingProject\jax-env\lib\site-packages\pennylane\gradients\jvp.py:179, in compute_jvp_single(tangent, jac)
177 new_shape = shape[: len(shape) - first_tangent_ndim] + (tangent_size,)
178 jac = qml.math.cast(qml.math.convert_like(jac, tangent), tangent.dtype)
--> 179 jac = qml.math.reshape(jac, new_shape)
180 return qml.math.tensordot(jac, tangent, [[-1], [0]])
182 tangent_ndims = [getattr(t, "ndim", 0) for t in tangent]
File ~\Desktop\TrainingProject\jax-env\lib\site-packages\autoray\autoray.py:81, in do(fn, like, *args, **kwargs)
79 backend = _choose_backend(fn, args, kwargs, like=like)
80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)
[... skipping hidden 3 frame]
File ~\Desktop\TrainingProject\jax-env\lib\site-packages\jax\_src\numpy\array_methods.py:470, in _compute_newshape(arr, newshape)
467 else:
468 if (all(isinstance(d, int) for d in (*arr.shape, *newshape)) and
469 arr.size != math.prod(newshape)):
--> 470 raise TypeError(f"cannot reshape array of shape {arr.shape} (size {arr.size}) "
471 f"into shape {orig_newshape} (size {math.prod(newshape)})")
472 return tuple(-core.divide_shape_sizes(arr.shape, newshape)
473 if core.definitely_equal(d, -1) else d for d in newshape)
TypeError: cannot reshape array of shape () (size 1) into shape (6,) (size 6)
The problem seems to be caused by using a self-defined Operation with trainable parameters; using the same Operation with fixed (non-trainable) parameters does not cause the error (see the sketch after the example below).
Here is a minimal example that reproduces the error for me:
import jax
import jax.numpy as jnp
import optax
import pennylane as qml
from pennylane.operation import AnyWires, Operation


class VariationalNetwork(Operation):
    num_wires = AnyWires
    grad_method = "A"  # parameter-shift differentiation

    def __init__(self, angle_array, wires, depth=None, id=None):
        # depth is not trainable but influences the action of the operator,
        # which is why we define it to be a hyperparameter
        self._hyperparameters = {"depth": depth}
        # The parent class expects all trainable parameters to be fed as positional
        # arguments, and all wires acted on fed as a keyword argument.
        # The id keyword argument allows users to give their instance a custom name.
        super().__init__(angle_array, wires=wires, id=id)  # calls the init of the Operation class

    @staticmethod
    def compute_decomposition(angle_array, wires, depth):
        op_list = []
        # working with the angle array
        for d in range(depth):
            start = 0 if d % 2 == 0 else 1  # toggle starting index
            for w in range(start, len(wires) - 1, 2):
                op_list.append(qml.RY(jnp.array(angle_array[0, d]), wires=[wires[w]]))
                op_list.append(qml.RY(jnp.array(angle_array[1, d]), wires=[wires[w]]))
        return op_list
key = jax.random.PRNGKey(2)


@jax.jit
def train_one_step(angles_one, key=key):
    dev = qml.device("default.qubit", wires=3, shots=100)

    @qml.qnode(dev, interface="jax")
    def circuit(angles_one):
        VariationalNetwork(angles_one, wires=[0, 1], depth=3)
        return qml.expval(qml.PauliZ(0))

    # fig, ax = qml.draw_mpl(circuit)(angles_one)
    # plt.show()
    return circuit(angles_one)


lr = 0.001  # Adam learning rate
training_iterations_adam = 100
opt1 = optax.adam(learning_rate=lr)


@jax.jit
def cost_overlap(angles_one):
    return 1 - train_one_step(angles_one)


def update_step_jit(i, args):
    angles_one, opt_state, print_training, loss_history = args

    # Compute loss and gradients w.r.t. the angle array
    loss_val, grads = jax.value_and_grad(cost_overlap, argnums=0)(angles_one)

    # Apply optimization step
    updates, opt_state = opt1.update(grads, opt_state)
    angles_one = optax.apply_updates(angles_one, updates)

    loss_history = loss_history.at[i].set(loss_val)  # store the loss

    # Debug printing every 10 steps if print_training is True
    def print_fn(_):
        jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=loss_val)

    jax.lax.cond((i % 10 == 0) & print_training, print_fn, lambda _: None, operand=None)

    return (angles_one, opt_state, print_training, loss_history)


@jax.jit
def optimization_jit(angles_one, print_training=False):
    opt_state = opt1.init(angles_one)
    loss_history = jnp.zeros(training_iterations_adam)

    args = (angles_one, opt_state, print_training, loss_history)
    angles_one, _, _, loss_history = jax.lax.fori_loop(0, training_iterations_adam, update_step_jit, args)

    return angles_one, loss_history


weights_init = jax.random.uniform(key, (2, 3), minval=0.0, maxval=jnp.pi / 2)
wights_2 = jnp.array([0.1])  # not used in this minimal example
optimization_jit(weights_init, True)
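For context, this is roughly what I mean by "fixed gates" above. It is a reconstructed sketch, not my exact code (circuit_fixed and fixed_angles are just illustration names): the custom Operation only receives constant angles, the single trainable parameter sits outside of it, and in that setup the reshape error did not appear.

# Sketch of the "fixed gates" variant: the Operation gets constant angles,
# only the rotation angle theta outside of it is trainable.
fixed_angles = jnp.full((2, 3), 0.3)  # hard-coded, never differentiated

dev_fixed = qml.device("default.qubit", wires=3, shots=100)

@qml.qnode(dev_fixed, interface="jax")
def circuit_fixed(theta):
    qml.RY(theta, wires=0)  # the only trainable parameter
    VariationalNetwork(fixed_angles, wires=[0, 1], depth=3)
    return qml.expval(qml.PauliZ(0))

# In this setup the error did not show up for me.
jax.jit(jax.value_and_grad(lambda theta: 1 - circuit_fixed(theta)))(0.1)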
Note: the error only appears for shots != None. It also works if I just call the cost function; the error is only raised inside the jitted training routine.
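Concretely, this call still runs fine with shots=100; only optimization_jit fails:

# Forward pass alone works with shots on; the failure only appears once the
# gradient is taken inside the jitted fori_loop in optimization_jit.
print(cost_overlap(weights_init))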
It can be avoided by not defining an Operation at all. I can do that, but it gets a bit annoying when plotting large quantum circuits.
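The workaround I mean looks roughly like this (same gate pattern as compute_decomposition, just written out inside the QNode instead of wrapped in an Operation):

# Workaround sketch: inline the gates directly in the QNode instead of using
# the custom Operation. This avoids the error, but qml.draw_mpl then shows
# every single rotation, which gets cluttered for larger circuits.
dev_inline = qml.device("default.qubit", wires=3, shots=100)

@qml.qnode(dev_inline, interface="jax")
def circuit_inline(angles_one, depth=3):
    wires = [0, 1]
    for d in range(depth):
        start = 0 if d % 2 == 0 else 1  # toggle starting index
        for w in range(start, len(wires) - 1, 2):
            qml.RY(angles_one[0, d], wires=wires[w])
            qml.RY(angles_one[1, d], wires=wires[w])
    return qml.expval(qml.PauliZ(0))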
In general, it feels like using shots slows the code down a lot. Is there any way to mitigate that?
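To make the "slows down a lot" a bit more concrete, here is a rough sketch of how one could time it (make_forward is just a helper name for this sketch, and it only times the forward pass, not the training loop, since the gradient is where it currently errors out):

# Rough timing sketch: compare an analytic forward pass with a shot-based one.
import time

def make_forward(shots):
    dev_t = qml.device("default.qubit", wires=3, shots=shots)

    @qml.qnode(dev_t, interface="jax")
    def circ(angles):
        VariationalNetwork(angles, wires=[0, 1], depth=3)
        return qml.expval(qml.PauliZ(0))

    return jax.jit(circ)

for shots in (None, 100):
    forward = make_forward(shots)
    jax.block_until_ready(forward(weights_init))  # compile once
    t0 = time.perf_counter()
    for _ in range(50):
        jax.block_until_ready(forward(weights_init))
    print(f"shots={shots}: {(time.perf_counter() - t0) / 50:.4f} s per call")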
Best regards,
Pia