### Expected behavior
A QNode transformed with `mitigate_with_zne` should accept parameters with a leading (broadcast) dimension.
### Actual behavior
A QNode transformed with `mitigate_with_zne` raises a `ValueError` when given parameters with a leading dimension. Applying `qml.transforms.broadcast_expand` to the QNode works around the issue (see the sketch after the source code below).
### Additional information
Reported here: https://discuss.pennylane.ai/t/error-with-keraslayer-and-amplitude-embedding/3924
Please update the forum post when progress is made.
### Source code
```python
import pennylane as qml
from pennylane import numpy as np
from pennylane.transforms import richardson_extrapolate, fold_global
nqbits = 4
batch_size = 2

dev_ideal = qml.device("default.mixed", wires=nqbits)
dev_mixed = qml.transforms.insert(dev_ideal, qml.DepolarizingChannel, 0.1)  # add noise

# @qml.transforms.broadcast_expand  # uncommenting this works around the issue
@qml.qnode(dev_mixed)
def original_qnode(inputs):
    qml.AmplitudeEmbedding(features=inputs, wires=range(nqbits), normalize=True)
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(nqbits)]

mitigated_qnode = qml.transforms.mitigate_with_zne(
    original_qnode, [1, 2, 3], fold_global, richardson_extrapolate
)

# Batched inputs with a leading dimension of size `batch_size`
inputs = np.random.uniform(0, 1, size=(batch_size, 2**nqbits))
mitigated_qnode(inputs)
```
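For reference, a minimal sketch of the workaround mentioned above, applying `broadcast_expand` as a function rather than as a decorator (the two should be equivalent):

```python
# Sketch of the workaround: expand the broadcast dimension before mitigation,
# so each batch element is executed on its own tape.
expanded_qnode = qml.transforms.broadcast_expand(original_qnode)
mitigated_qnode = qml.transforms.mitigate_with_zne(
    expanded_qnode, [1, 2, 3], fold_global, richardson_extrapolate
)
mitigated_qnode(inputs)  # runs without the ValueError below
```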
### Tracebacks
```shell
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[16], line 20
17 mitigated_qnode = qml.transforms.mitigate_with_zne(original_qnode, [1,2,3], fold_global, richardson_extrapolate)
19 inputs = np.random.uniform(0, 1, size=(batch_size, 2**nqbits))
---> 20 mitigated_qnode(inputs)
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/qnode.py:1039, in QNode.__call__(self, *args, **kwargs)
1034 full_transform_program._set_all_argnums(
1035 self, args, kwargs, argnums
1036 ) # pylint: disable=protected-access
1038 # pylint: disable=unexpected-keyword-arg
-> 1039 res = qml.execute(
1040 (self._tape,),
1041 device=self.device,
1042 gradient_fn=self.gradient_fn,
1043 interface=self.interface,
1044 transform_program=full_transform_program,
1045 config=config,
1046 gradient_kwargs=self.gradient_kwargs,
1047 override_shots=override_shots,
1048 **self.execute_kwargs,
1049 )
1051 res = res[0]
1053 # convert result to the interface in case the qfunc has no parameters
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/interfaces/execution.py:649, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp)
647 if no_interface_boundary_required:
648 results = inner_execute(tapes)
--> 649 return post_processing(results)
651 _grad_on_execution = False
653 if config.use_device_jacobian_product and interface in jpc_interfaces:
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/interfaces/execution.py:641, in execute.<locals>.post_processing(results)
640 def post_processing(results):
--> 641 return program_post_processing(program_pre_processing(results))
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/core/transform_program.py:86, in _apply_postprocessing_stack(results, postprocessing_stack)
63 """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
64
65 Args:
(...)
83
84 """
85 for postprocessing in reversed(postprocessing_stack):
---> 86 results = postprocessing(results)
87 return results
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/core/transform_program.py:56, in _batch_postprocessing(results, individual_fns, slices)
30 def _batch_postprocessing(
31 results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
32 ) -> ResultBatch:
33 """Broadcast individual post processing functions onto their respective tapes.
34
35 Args:
(...)
54
55 """
---> 56 return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/core/transform_program.py:56, in <genexpr>(.0)
30 def _batch_postprocessing(
31 results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
32 ) -> ResultBatch:
33 """Broadcast individual post processing functions onto their respective tapes.
34
35 Args:
(...)
54
55 """
---> 56 return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/mitigate.py:541, in mitigate_with_zne.<locals>.processing_fn(results)
536 for i in range(0, len(results), reps_per_factor):
537 # The stacking ensures the right interface is used
538 # averaging over axis=0 is critical because the qnode may have multiple outputs
539 results_flattened.append(mean(qml.math.stack(results[i : i + reps_per_factor]), axis=0))
--> 541 extrapolated = extrapolate(scale_factors, results_flattened, **extrapolate_kwargs)
543 extrapolated = extrapolated[0] if shape(extrapolated) == (1,) else extrapolated
545 # unstack the results in the case of multiple measurements
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/mitigate.py:320, in richardson_extrapolate(x, y)
297 def richardson_extrapolate(x, y):
298 r"""Polynomial fit where the degree of the polynomial is fixed to being equal to the length of ``x``.
299
300 In a nutshell, this function is calling :func:`~.pennylane.transforms.poly_extrapolate` with ``order = len(x)-1``.
(...)
318
319 """
--> 320 return poly_extrapolate(x, y, len(x) - 1)
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/mitigate.py:293, in poly_extrapolate(x, y, order)
269 def poly_extrapolate(x, y, order):
270 r"""Extrapolator to :math:`f(0)` for polynomial fit.
271
272 The polynomial is defined as ``f(x) = p[0] * x**deg + p[1] * x**(deg-1) + ... + p[deg]`` such that ``deg = order + 1``.
(...)
291
292 """
--> 293 coeff = _polyfit(x, y, order)
294 return coeff[-1]
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/transforms/mitigate.py:264, in _polyfit(x, y, order)
262 c = qml.math.linalg.pinv(qml.math.transpose(X) @ X)
263 c = c @ qml.math.transpose(X)
--> 264 c = qml.math.dot(c, y)
265 c = qml.math.transpose(qml.math.transpose(c) / scale)
266 return c
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/math/multi_dispatch.py:151, in multi_dispatch.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
148 interface = interface or get_interface(*dispatch_args)
149 kwargs["like"] = interface
--> 151 return fn(*args, **kwargs)
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/pennylane/math/multi_dispatch.py:358, in dot(tensor1, tensor2, like)
354 return x @ y
356 return np.tensordot(x, y, axes=[[-1], [-2]], like=like)
--> 358 return np.dot(x, y, like=like)
File ~/.virtualenvs/pennylane-torch/lib/python3.11/site-packages/autoray/autoray.py:80, in do(fn, like, *args, **kwargs)
31 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
32 dispatch to retrieve ``fn`` based on whichever library defines the class of
33 the ``args[0]``, or the ``like`` keyword argument if specified.
(...)
77 <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
78 """
79 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 80 return get_lib_fn(backend, fn)(*args, **kwargs)
ValueError: shapes (3,3) and (3,4,2) not aligned: 3 (dim 1) != 4 (dim 1)
```
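For context, the final `ValueError` appears to come from the stacked ZNE results carrying an extra batch axis when they reach `_polyfit`. A minimal NumPy sketch of the same shape mismatch, with shapes chosen to match this example (3 scale factors, 4 measured wires, batch of 2):

```python
import numpy as np

# Hypothetical shapes matching this example: 3 scale factors, 4 measured
# wires, and a batch of 2, so the stacked results have shape (3, 4, 2).
c = np.zeros((3, 3))     # stands in for the (scale_factors x scale_factors) fit matrix
y = np.zeros((3, 4, 2))  # stands in for the stacked, batched results

# np.dot contracts the last axis of `c` (size 3) with the second-to-last axis
# of `y` (size 4), reproducing the error from the traceback.
try:
    np.dot(c, y)
except ValueError as err:
    print(err)  # shapes (3,3) and (3,4,2) not aligned: 3 (dim 1) != 4 (dim 1)

# Contracting over the leading (scale-factor) axis of `y` instead is fine:
print(np.tensordot(c, y, axes=[[-1], [0]]).shape)  # (3, 4, 2)
```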
### System information
```shell
Name: PennyLane
Version: 0.34.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-torch/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning
Platform info: macOS-14.2.1-arm64-arm-64bit
Python version: 3.11.6
Numpy version: 1.26.2
Scipy version: 1.11.4
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.34.0)
- default.gaussian (PennyLane-0.34.0)
- default.mixed (PennyLane-0.34.0)
- default.qubit (PennyLane-0.34.0)
- default.qubit.autograd (PennyLane-0.34.0)
- default.qubit.jax (PennyLane-0.34.0)
- default.qubit.legacy (PennyLane-0.34.0)
- default.qubit.tf (PennyLane-0.34.0)
- default.qubit.torch (PennyLane-0.34.0)
- default.qutrit (PennyLane-0.34.0)
- null.qubit (PennyLane-0.34.0)
```
### Existing GitHub issues
- [X] I have searched existing GitHub issues to make sure the issue does not already exist.