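I’m wondering whether JIT-compiling the evaluation of the quantum Fisher information (i.e. the metric tensor) is supported. The answer might already be implied by this forum post, but it’s not clear to me whether this is meant to be supported (and the behavior below is a bug) or whether the functionality just isn’t there yet. Below is a small example of a cost function, the trace of the metric tensor, and I want to JIT-compile the `qfim_trace` function. This succeeds when `use_jax == False`, but fails when `use_jax == True`: there appears to be a tracer issue during JAX compilation. According to the trace below, the `TracerBoolConversionError` is raised in `_contract_metric_tensor_with_cjac` (`pennylane/gradients/metric_tensor.py`). There, the traced result of `qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0]))` is used as the condition of an `if` statement.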
import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane import numpy as pnp
from functools import partial
# Number of RX rotations applied by ``circuit`` (one angle per entry of ``x``).
num_params = 16

# Two wires, although ``circuit`` only acts on wire 0. Presumably the spare
# wire is the auxiliary wire that ``qml.metric_tensor`` needs for its
# Hadamard-test tapes -- TODO confirm.
dev = qml.device("default.qubit", wires=2)
@qml.qnode(dev)
def circuit(x):
    """Apply ``num_params`` RX rotations to wire 0 and return <Z> on wire 0.

    The rotation angles are read from ``x`` in index order, so ``x`` must
    have at least ``num_params`` entries.
    """
    for idx in range(num_params):
        angle = x[idx]
        qml.RX(angle, wires=0)
    return qml.expval(qml.Z(0))
def qfim_trace(x, use_jax=False):
    """Return the trace of the metric tensor of ``circuit`` evaluated at ``x``.

    Args:
        x: circuit parameters, forwarded unchanged to ``circuit``.
        use_jax: if True, take the trace with ``jax.numpy``; otherwise take it
            with ``pennylane.numpy``.
    """
    take_trace = jnp.trace if use_jax else pnp.trace
    metric = qml.metric_tensor(circuit)(x)
    return take_trace(metric)
use_jax = True
temp = partial(qfim_trace, use_jax=use_jax)

# Build the input with the array library matching the chosen interface.
# Only the JAX path is JIT-compiled.
x = (jnp if use_jax else pnp).arange(float(num_params))
cost_fn = jax.jit(temp) if use_jax else temp

# On PennyLane 0.38 with use_jax=True, this raises TracerBoolConversionError.
# metric_tensor's post-processing (_contract_metric_tensor_with_cjac)
# branches on the traced result of qml.math.allclose(...).
cost_fn(x)  # <-- fails when use_jax == True
The error trace is below. For some reason the forum won’t let my post contain the word “c all”, so I replaced it with “<>” in the error trace.
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent <> last)
Cell In[37], line 31
28 x = pnp.arange(float(num_params))
29 cost_fn = temp
---> 31 cost_fn(x)
[... skipping hidden 11 frame]
Cell In[37], line 16
15 def qfim_trace(x, use_jax=False):
---> 16 qfim = qml.metric_tensor(circuit)(x)
17 if use_jax:
18 return jnp.trace(qfim)
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/qnode.py:1020, in QNode.__<>__(self, *args, **kwargs)
1018 if qml.capture.enabled():
1019 return qml.capture.qnode_<>(self, *args, **kwargs)
-> 1020 return self._impl_<>(*args, **kwargs)
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/qnode.py:1008, in QNode._impl_<>(self, *args, **kwargs)
1005 self._update_gradient_fn(shots=override_shots, tape=self._tape)
1007 try:
-> 1008 res = self._execution_component(args, kwargs, override_shots=override_shots)
1009 finally:
1010 if old_interface == "auto":
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/qnode.py:957, in QNode._execution_component(self, args, kwargs, override_shots)
951 warnings.filterwarnings(
952 action="ignore",
953 message=r".*argument is deprecated and will be removed in version 0.39.*",
954 category=qml.PennyLaneDeprecationWarning,
955 )
956 # pylint: disable=unexpected-keyword-arg
--> 957 res = qml.execute(
958 (self._tape,),
959 device=self.device,
960 gradient_fn=self.gradient_fn,
961 interface=self.interface,
962 transform_program=full_transform_program,
963 inner_transform=inner_transform_program,
964 config=config,
965 gradient_kwargs=self.gradient_kwargs,
966 override_shots=override_shots,
967 **execute_kwargs,
968 )
969 res = res[0]
971 # convert result to the interface in case the qfunc has no parameters
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/execution.py:661, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
659 if no_interface_boundary_required:
660 results = inner_execute(tapes)
--> 661 return post_processing(results)
663 if (
664 device_vjp
665 and getattr(device, "short_name", "") in ("lightning.gpu", "lightning.kokkos")
666 and interface in jpc_interfaces
667 ): # pragma: no cover
668 if INTERFACE_MAP[interface] == "jax" and "use_device_state" in gradient_kwargs:
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:84, in _apply_postprocessing_stack(results, postprocessing_stack)
61 """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
62
63 Args:
(...)
81
82 """
83 for postprocessing in reversed(postprocessing_stack):
---> 84 results = postprocessing(results)
85 return results
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:54, in _batch_postprocessing(results, individual_fns, slices)
28 def _batch_postprocessing(
29 results: ResultBatch, individual_fns: list[PostprocessingFn], slices: list[slice]
30 ) -> ResultBatch:
31 """Broadcast individual post processing functions onto their respective tapes.
32
33 Args:
(...)
52
53 """
---> 54 return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:54, in <genexpr>(.0)
28 def _batch_postprocessing(
29 results: ResultBatch, individual_fns: list[PostprocessingFn], slices: list[slice]
30 ) -> ResultBatch:
31 """Broadcast individual post processing functions onto their respective tapes.
32
33 Args:
(...)
52
53 """
---> 54 return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/gradients/metric_tensor.py:61, in _contract_metric_tensor_with_cjac(mt, cjac, tape)
57 return metric_tensors[0] if len(metric_tensors) == 1 else metric_tensors
59 is_square = cjac.shape == (1,) or (cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1])
---> 61 if is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0])):
62 # Classical Jacobian is the identity. No classical processing
63 # is present inside the QNode.
64 return mt
65 mt_cjac = qml.math.tensordot(mt, cjac, axes=[[-1], [0]])
[... skipping hidden 1 frame]
File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
1474 def error(self, arg):
-> 1475 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function qfim_trace at /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:15 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[16,16] = add b c
from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:16:11 (qfim_trace)
operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (15, 0, 0))] b c
from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)
operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (14, 1, 0))] b c
from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)
operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (13, 2, 0))] b c
from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)
operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (12, 3, 0))] b c
from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)
(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
qml.about()
Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/joey/miniconda3/envs/lanl/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning
Platform info: macOS-14.4.1-arm64-arm-64bit
Python version: 3.12.6
Numpy version: 1.26.4
Scipy version: 1.12.0
Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.8.1)
- nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
- oqc.cloud (PennyLane-Catalyst-0.8.1)
- softwareq.qpp (PennyLane-Catalyst-0.8.1)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)