Hello PennyLane team,

I am writing a quantum machine learning algorithm in which I need to estimate the expectation values of many observables using classical shadows in order to compute a loss function. I found qml.shadow_expval, which nicely lets me batch over the observables. However, I have not managed to use jax.vmap to batch over the data inputs to the circuit. Below is a simplified version of my code and the error.
import jax
import jax.numpy as jnp
import pennylane as qml

observables = [qml.PauliZ(0), qml.PauliZ(1), qml.PauliZ(2)]
nq = 3
dev = qml.device("default.qubit", wires=nq, shots=4096)

@qml.qnode(dev)
def qnode(inputs, weights):
    for i in range(nq):
        qml.RX(inputs[0], wires=i)
        qml.RY(inputs[1], wires=i)
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    return qml.shadow_expval(observables)

key = jax.random.PRNGKey(seed=0)
weights = jax.random.uniform(key=key, shape=(2,))
inputs = jnp.ones((16, 2))

# qnode(inputs, weights)
jax.vmap(qnode, in_axes=(0, None))(inputs, weights)  # batching over the 16 input samples raises the error below
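For reference, the fallback I can use is an explicit Python loop over the batch. A minimal sketch (assuming the per-sample results can simply be stacked) would be:

# Hypothetical fallback: loop over the batch in Python instead of vmapping.
# Each sample becomes a separate QNode call, which gives up the vectorization I am after.
results = jnp.stack([jnp.asarray(qnode(x, weights)) for x in inputs])

The full error from the jax.vmap call is: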
jax.pure_callback failed
Traceback (most recent call last):
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py", line 79, in pure_callback_impl
return tree_util.tree_map(np.asarray, callback(*args))
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py", line 64, in __call__
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
return _to_jax(execute_fn(new_tapes))
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/execution.py", line 212, in inner_execute
results = device.execute(transformed_tapes, execution_config=execution_config)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
results = untracked_execute(self, circuits, execution_config)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
results = batch_execute(self, circuits, execution_config)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 630, in execute
return tuple(
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 631, in <genexpr>
_simulate_wrapper(
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 896, in _simulate_wrapper
return simulate(circuit, **kwargs)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 384, in simulate
return measure_final_state(
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 280, in measure_final_state
results = measure_with_samples(
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 255, in measure_with_samples
measure_fn(
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 384, in _measure_classical_shadow
return [mp.process_state_with_shots(state, wires, shots.total_shots, rng=rng)]
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 544, in process_state_with_shots
bits, recipes = qml.classical_shadow(
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 412, in process_state_with_shots
probs = (np.einsum("abc,acb->a", first_qubit_state, obs[:, active_qubit]) + 1) / 2
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/numpy/core/einsumfunc.py", line 1371, in einsum
return c_einsum(*operands, **kwargs)
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (4096,16,16)->(4096,16,16) (4096,2,2)->(4096,2,2)
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[50], line 5
3 inputs=jnp.ones((16,2))
4 qnode(inputs,weights)
----> 5 jax.vmap(qnode, in_axes=(0,None))(inputs, weights)
[... skipping hidden 3 frame]
File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/qnode.py:1020, in QNode.__call__(self, *args, **kwargs)
1018 if qml.capture.enabled():
1019 return qml.capture.qnode_call(self, *args, **kwargs)
-> 1020 return self._impl_call(*args, **kwargs)
File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/qnode.py:1008, in QNode._impl_call(self, *args, **kwargs)
1005 self._update_gradient_fn(shots=override_shots, tape=self._tape)
1007 try:
-> 1008 res = self._execution_component(args, kwargs, override_shots=override_shots)
1009 finally:
1010 if old_interface == "auto":
File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/qnode.py:957, in QNode._execution_component(self, args, kwargs, override_shots)
951 warnings.filterwarnings(
952 action="ignore",
953 message=r".*argument is deprecated and will be removed in version 0.39.*",
954 category=qml.PennyLaneDeprecationWarning,
955 )
956 # pylint: disable=unexpected-keyword-arg
--> 957 res = qml.execute(
958 (self._tape,),
959 device=self.device,
960 gradient_fn=self.gradient_fn,
961 interface=self.interface,
962 transform_program=full_transform_program,
963 inner_transform=inner_transform_program,
964 config=config,
965 gradient_kwargs=self.gradient_kwargs,
966 override_shots=override_shots,
967 **execute_kwargs,
968 )
969 res = res[0]
971 # convert result to the interface in case the qfunc has no parameters
File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/execution.py:771, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
763 ml_boundary_execute = _get_ml_boundary_execute(
764 interface,
765 config.grad_on_execution,
766 config.use_device_jacobian_product,
767 differentiable=max_diff > 1,
768 )
770 if interface in jpc_interfaces:
--> 771 results = ml_boundary_execute(tapes, execute_fn, jpc, device=device)
772 else:
773 results = ml_boundary_execute(
774 tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
775 )
File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/interfaces/jax_jit.py:264, in jax_jit_jvp_execute(tapes, execute_fn, jpc, device)
260 raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
262 parameters = tuple(tuple(t.get_parameters(trainable_only=False)) for t in tapes)
--> 264 return _execute_jvp_jit(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc, device)
[... skipping hidden 7 frame]
File ~/Documents/evs_venv/lib/python3.9/site-packages/pennylane/workflow/interfaces/jax_jit.py:178, in _execute_wrapper_inner(params, tapes, execute_fn, _, device, is_vjp)
172 else:
173 # first order way of determining native parameter broadcasting support
174 # will be inaccurate when inclusion of broadcast_expand depends on ExecutionConfig values (like adjoint)
175 device_supports_vectorization = (
176 qml.transforms.broadcast_expand not in device.preprocess()[0]
177 )
--> 178 out = jax.pure_callback(
179 pure_callback_wrapper, shape_dtype_structs, params, vectorized=device_supports_vectorization
180 )
181 return out
File ~/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py:320, in pure_callback(callback, result_shape_dtypes, sharding, vectorized, *args, **kwargs)
317 result_avals = tree_util.tree_map(
318 lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)
319 flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
--> 320 out_flat = pure_callback_p.bind(
321 *flat_args,
322 callback=_FlatCallback(callback, in_tree),
323 result_avals=tuple(flat_result_avals),
324 sharding=sharding,
325 vectorized=vectorized,
326 )
327 return tree_util.tree_unflatten(out_tree, out_flat)
[... skipping hidden 3 frame]
File ~/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/callback.py:137, in pure_callback_batching_rule(args, dims, callback, sharding, vectorized, result_avals)
133 if vectorized:
134 result_avals = tuple(
135 core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
136 for aval in result_avals)
--> 137 outvals = pure_callback_p.bind(
138 *new_args,
139 callback=callback,
140 sharding=sharding,
141 vectorized=vectorized,
142 result_avals=result_avals,
143 )
144 else:
145 is_batched = [d is not batching.not_mapped for d in dims]
[... skipping hidden 14 frame]
File ~/Documents/evs_venv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1177, in ExecuteReplicated.__call__(self, *args)
1174 if (self.ordered_effects or self.has_unordered_effects
1175 or self.has_host_callbacks):
1176 input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1177 results = self.xla_executable.execute_sharded(
1178 input_bufs, with_tokens=True
1179 )
1180 result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
1181 len(self.ordered_effects))
1182 sharded_runtime_token = results.consume_token()
XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: Traceback (most recent call last):
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 631, in <genexpr>
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/default_qubit.py", line 896, in _simulate_wrapper
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 384, in simulate
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/simulate.py", line 280, in measure_final_state
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 255, in measure_with_samples
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/devices/qubit/sampling.py", line 384, in _measure_classical_shadow
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 544, in process_state_with_shots
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/pennylane/measurements/classical_shadow.py", line 412, in process_state_with_shots
File "/Users/kevinshen/Documents/evs_venv/lib/python3.9/site-packages/numpy/core/einsumfunc.py", line 1371, in einsum
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (4096,16,16)->(4096,16,16) (4096,2,2)->(4096,2,2)
A more fundamental question: is jax.vmap even beneficial in this case? I have access to GPUs and saw a speedup when I used jax.vmap on circuits returning qml.expval, but I don't know whether the same holds for qml.shadow_expval.
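For context, this is roughly the pattern that gave me a speedup before (a minimal sketch using qml.expval in analytic mode; dev_exact and expval_qnode are just illustrative names mirroring the code above):

# Analytic device (no shots) for the expval version, used only for this comparison.
dev_exact = qml.device("default.qubit", wires=nq)

@qml.qnode(dev_exact)
def expval_qnode(inputs, weights):
    for i in range(nq):
        qml.RX(inputs[0], wires=i)
        qml.RY(inputs[1], wires=i)
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    return [qml.expval(o) for o in observables]

# Batching over the 16 input samples; this kind of vmap over expval circuits was faster for me on GPU.
batched = jax.vmap(expval_qnode, in_axes=(0, None))(inputs, weights)

I am hoping the same kind of batching can carry over to qml.shadow_expval, or, if not, I would appreciate advice on the recommended way to batch over inputs.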