shadow_expval batching not working with jax.vmap

Hello Pennylane team,

I am writing a QML algorithm in which I need to estimate the expectation values of many observables via classical shadows in order to compute a loss function. I found qml.shadow_expval, which nicely lets me batch over the observables. However, I failed when applying jax.vmap to batch over the data inputs to the circuit. Below is a simplified version of my code and the error.

import jax
import jax.numpy as jnp
import pennylane as qml

# Observables whose expectation values are estimated from the classical shadow
observables = [qml.PauliZ(0), qml.PauliZ(1), qml.PauliZ(2)]
nq = 3

dev = qml.device("default.qubit", wires=nq, shots=4096)

@qml.qnode(dev)
def qnode(inputs, weights):
    # Data-encoding rotations on every qubit
    for i in range(nq):
        qml.RX(inputs[0], wires=i)
        qml.RY(inputs[1], wires=i)
    # Trainable rotations
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    return qml.shadow_expval(observables)

key = jax.random.PRNGKey(seed=0)
weights = jax.random.uniform(key=key, shape=(2,))
inputs = jnp.ones((16, 2))
# qnode(inputs, weights)
jax.vmap(qnode, in_axes=(0, None))(inputs, weights)  # batch over the data inputs only

jax.pure_callback failed
Traceback (most recent call last):
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py", line 79, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py", line 64, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
    return _to_jax(execute_fn(new_tapes))
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/execution.py", line 212, in inner_execute
    results = device.execute(transformed_tapes, execution_config=execution_config)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
    results = untracked_execute(self, circuits, execution_config)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
    results = batch_execute(self, circuits, execution_config)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 630, in execute
    return tuple(
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 631, in <genexpr>
    _simulate_wrapper(
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 896, in _simulate_wrapper
    return simulate(circuit, **kwargs)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 384, in simulate
    return measure_final_state(
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 280, in measure_final_state
    results = measure_with_samples(
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 255, in measure_with_samples
    measure_fn(
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 384, in _measure_classical_shadow
    return [mp.process_state_with_shots(state, wires, shots.total_shots, rng=rng)]
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 544, in process_state_with_shots
    bits, recipes = qml.classical_shadow(
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 412, in process_state_with_shots
    probs = (np.einsum("abc,acb->a", first_qubit_state, obs[:, active_qubit]) + 1) / 2
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/numpy/core/einsumfunc.py", line 1371, in einsum
    return c_einsum(*operands, **kwargs)
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (4096,16,16)->(4096,16,16) (4096,2,2)->(4096,2,2) 
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[50], line 5
      3 inputs=jnp.ones((16,2))
      4 qnode(inputs,weights)
----> 5 jax.vmap(qnode, in_axes=(0,None))(inputs, weights)

    [... skipping hidden 3 frame]

File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/qnode.py:1020, in QNode.__call__(self, *args, **kwargs)
   1018 if qml.capture.enabled():
   1019     return qml.capture.qnode_call(self, *args, **kwargs)
-> 1020 return self._impl_call(*args, **kwargs)

File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/qnode.py:1008, in QNode._impl_call(self, *args, **kwargs)
   1005 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1007 try:
-> 1008     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1009 finally:
   1010     if old_interface == "auto":

File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/qnode.py:957, in QNode._execution_component(self, args, kwargs, override_shots)
    951     warnings.filterwarnings(
    952         action="ignore",
    953         message=r".*argument is deprecated and will be removed in version 0.39.*",
    954         category=qml.PennyLaneDeprecationWarning,
    955     )
    956     # pylint: disable=unexpected-keyword-arg
--> 957     res = qml.execute(
    958         (self._tape,),
    959         device=self.device,
    960         gradient_fn=self.gradient_fn,
    961         interface=self.interface,
    962         transform_program=full_transform_program,
    963         inner_transform=inner_transform_program,
    964         config=config,
    965         gradient_kwargs=self.gradient_kwargs,
    966         override_shots=override_shots,
    967         **execute_kwargs,
    968     )
    969 res = res[0]
    971 # convert result to the interface in case the qfunc has no parameters

File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/execution.py:771, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
    763 ml_boundary_execute = _get_ml_boundary_execute(
    764     interface,
    765     config.grad_on_execution,
    766     config.use_device_jacobian_product,
    767     differentiable=max_diff > 1,
    768 )
    770 if interface in jpc_interfaces:
--> 771     results = ml_boundary_execute(tapes, execute_fn, jpc, device=device)
    772 else:
    773     results = ml_boundary_execute(
    774         tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
    775     )

File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/interfaces/jax_jit.py:264, in jax_jit_jvp_execute(tapes, execute_fn, jpc, device)
    260     raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
    262 parameters = tuple(tuple(t.get_parameters(trainable_only=False)) for t in tapes)
--> 264 return _execute_jvp_jit(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc, device)

    [... skipping hidden 7 frame]

File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/interfaces/jax_jit.py:178, in _execute_wrapper_inner(params, tapes, execute_fn, _, device, is_vjp)
    172 else:
    173     # first order way of determining native parameter broadcasting support
    174     # will be inaccurate when inclusion of broadcast_expand depends on ExecutionConfig values (like adjoint)
    175     device_supports_vectorization = (
    176         qml.transforms.broadcast_expand not in device.preprocess()[0]
    177     )
--> 178 out = jax.pure_callback(
    179     pure_callback_wrapper, shape_dtype_structs, params, vectorized=device_supports_vectorization
    180 )
    181 return out

File ~/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py:320, in pure_callback(callback, result_shape_dtypes, sharding, vectorized, *args, **kwargs)
    317 result_avals = tree_util.tree_map(
    318     lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)
    319 flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
--> 320 out_flat = pure_callback_p.bind(
    321     *flat_args,
    322     callback=_FlatCallback(callback, in_tree),
    323     result_avals=tuple(flat_result_avals),
    324     sharding=sharding,
    325     vectorized=vectorized,
    326 )
    327 return tree_util.tree_unflatten(out_tree, out_flat)

    [... skipping hidden 3 frame]

File ~/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py:137, in pure_callback_batching_rule(args, dims, callback, sharding, vectorized, result_avals)
    133 if vectorized:
    134   result_avals = tuple(
    135       core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)  # type: ignore
    136       for aval in result_avals)
--> 137   outvals = pure_callback_p.bind(
    138       *new_args,
    139       callback=callback,
    140       sharding=sharding,
    141       vectorized=vectorized,
    142       result_avals=result_avals,
    143   )
    144 else:
    145   is_batched = [d is not batching.not_mapped for d in dims]

    [... skipping hidden 14 frame]

File ~/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1177, in ExecuteReplicated.__call__(self, *args)
   1174 if (self.ordered_effects or self.has_unordered_effects
   1175     or self.has_host_callbacks):
   1176   input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1177   results = self.xla_executable.execute_sharded(
   1178       input_bufs, with_tokens=True
   1179   )
   1180   result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
   1181       len(self.ordered_effects))
   1182   sharded_runtime_token = results.consume_token()

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: Traceback (most recent call last):
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 631, in <genexpr>
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 896, in _simulate_wrapper
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 384, in simulate
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 280, in measure_final_state
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 255, in measure_with_samples
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 384, in _measure_classical_shadow
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 544, in process_state_with_shots
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 412, in process_state_with_shots
  File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/numpy/core/einsumfunc.py", line 1371, in einsum
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (4096,16,16)->(4096,16,16) (4096,2,2)->(4096,2,2) 

A more fundamental question: is jax.vmap even beneficial in this case? (I have access to GPUs and saw a speedup when using vmap on circuits with expval, but I don’t know whether it will also be faster for shadow_expval.)

Hi @Kevin_Shen,

Since vmap is giving you trouble here, I’d recommend just keeping qnode(inputs, weights). I’m not sure what the exact source of the issue is, so it’s probably better not to use vmap with shadow_expval for now.

Edit: Actually, the issue is not specific to jax.vmap; shadow_expval just doesn’t seem to work with any form of parameter broadcasting at the moment. We’ll open a bug report and look into fixing this.
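In the meantime, one way to process the whole batch without vmap (or any parameter broadcasting) is to loop over the data points in Python, so that each circuit execution sees a single, unbatched input. This is only a minimal sketch, not an official recommendation; it assumes each qnode call returns one expectation value per observable, and it won’t give you vmap’s vectorization speedup:

# Workaround sketch: run the circuit once per data point, avoiding broadcasting.
# Assumes each call returns an array with one expectation value per observable.
results = jnp.stack([qnode(x, weights) for x in inputs])  # one row per data point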

Here’s the bug report for future reference: [BUG] shadow_expval doesn't work with parameter broadcasting · Issue #6301 · PennyLaneAI/pennylane · GitHub

Thanks for the quick reply, Catalina. It’s great that you agree this is a good feature to add. Looking forward to the progress.


Yes, absolutely! Thanks for making us aware of this, @Kevin_Shen 🙌