Using JAX with metric_tensor

I’m wondering whether jitting the evaluation of the quantum Fisher information matrix (QFIM), which for pure states is 4 times the metric tensor, is supported. The answer might already be implied by this forum post, but it’s not clear to me whether it’s meant to be supported (and the behavior below is a bug) or whether the functionality just isn’t there yet. Below is a small example of a cost function that is the trace of the metric tensor; I want to jit the qfim_trace function. It succeeds when use_jax == False, but with use_jax == True the JAX tracing step fails with a tracer error.

```python
import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane import numpy as pnp
from functools import partial

num_params = 16
dev = qml.device("default.qubit", wires=1+1)

@qml.qnode(dev)
def circuit(x):
    for ia in range(num_params):
        qml.RX(x[ia], wires=0)
    return qml.expval(qml.Z(0))

def qfim_trace(x, use_jax=False):
    qfim = qml.metric_tensor(circuit)(x)
    if use_jax:
        return jnp.trace(qfim)
    else:
        return pnp.trace(qfim)

use_jax = True
temp = partial(qfim_trace, use_jax=use_jax)
if use_jax:
    x = jnp.arange(float(num_params))
    cost_fn = jax.jit(temp)
else:
    x = pnp.arange(float(num_params))
    cost_fn = temp

cost_fn(x)  # <-- fails when use_jax == True
```

The error trace is below. For some reason the forum won’t let my post contain the word “c all”, so I replaced it with “<>” in the error trace.

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent <> last)
Cell In[37], line 31
     28     x = pnp.arange(float(num_params))
     29     cost_fn = temp
---> 31 cost_fn(x)

    [... skipping hidden 11 frame]

Cell In[37], line 16
     15 def qfim_trace(x, use_jax=False):
---> 16     qfim = qml.metric_tensor(circuit)(x)
     17     if use_jax:
     18         return jnp.trace(qfim)

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/qnode.py:1020, in QNode.__<>__(self, *args, **kwargs)
   1018 if qml.capture.enabled():
   1019     return qml.capture.qnode_<>(self, *args, **kwargs)
-> 1020 return self._impl_<>(*args, **kwargs)

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/qnode.py:1008, in QNode._impl_<>(self, *args, **kwargs)
   1005 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1007 try:
-> 1008     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1009 finally:
   1010     if old_interface == "auto":

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/qnode.py:957, in QNode._execution_component(self, args, kwargs, override_shots)
    951     warnings.filterwarnings(
    952         action="ignore",
    953         message=r".*argument is deprecated and will be removed in version 0.39.*",
    954         category=qml.PennyLaneDeprecationWarning,
    955     )
    956     # pylint: disable=unexpected-keyword-arg
--> 957     res = qml.execute(
    958         (self._tape,),
    959         device=self.device,
    960         gradient_fn=self.gradient_fn,
    961         interface=self.interface,
    962         transform_program=full_transform_program,
    963         inner_transform=inner_transform_program,
    964         config=config,
    965         gradient_kwargs=self.gradient_kwargs,
    966         override_shots=override_shots,
    967         **execute_kwargs,
    968     )
    969 res = res[0]
    971 # convert result to the interface in case the qfunc has no parameters

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/workflow/execution.py:661, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
    659 if no_interface_boundary_required:
    660     results = inner_execute(tapes)
--> 661     return post_processing(results)
    663 if (
    664     device_vjp
    665     and getattr(device, "short_name", "") in ("lightning.gpu", "lightning.kokkos")
    666     and interface in jpc_interfaces
    667 ):  # pragma: no cover
    668     if INTERFACE_MAP[interface] == "jax" and "use_device_state" in gradient_kwargs:

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:84, in _apply_postprocessing_stack(results, postprocessing_stack)
     61 """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
     62 
     63 Args:
   (...)
     81 
     82 """
     83 for postprocessing in reversed(postprocessing_stack):
---> 84     results = postprocessing(results)
     85 return results

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:54, in _batch_postprocessing(results, individual_fns, slices)
     28 def _batch_postprocessing(
     29     results: ResultBatch, individual_fns: list[PostprocessingFn], slices: list[slice]
     30 ) -> ResultBatch:
     31     """Broadcast individual post processing functions onto their respective tapes.
     32 
     33     Args:
   (...)
     52 
     53     """
---> 54     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py:54, in <genexpr>(.0)
     28 def _batch_postprocessing(
     29     results: ResultBatch, individual_fns: list[PostprocessingFn], slices: list[slice]
     30 ) -> ResultBatch:
     31     """Broadcast individual post processing functions onto their respective tapes.
     32 
     33     Args:
   (...)
     52 
     53     """
---> 54     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/pennylane/gradients/metric_tensor.py:61, in _contract_metric_tensor_with_cjac(mt, cjac, tape)
     57     return metric_tensors[0] if len(metric_tensors) == 1 else metric_tensors
     59 is_square = cjac.shape == (1,) or (cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1])
---> 61 if is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0])):
     62     # Classical Jacobian is the identity. No classical processing
     63     # is present inside the QNode.
     64     return mt
     65 mt_cjac = qml.math.tensordot(mt, cjac, axes=[[-1], [0]])

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/lanl/lib/python3.12/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function qfim_trace at /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:15 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[16,16] = add b c
    from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:16:11 (qfim_trace)

  operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (15, 0, 0))] b c
    from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)

  operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (14, 1, 0))] b c
    from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)

  operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (13, 2, 0))] b c
    from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)

  operation a:f64[16,16] = pad[padding_config=((0, 0, 0), (12, 3, 0))] b c
    from line /var/folders/2q/925k9byj00g101q99sk75gxh0000gn/T/ipykernel_36016/2968992411.py:12:15 (circuit)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
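The last frames point at the Python `if` in `_contract_metric_tensor_with_cjac` (pennylane/gradients/metric_tensor.py, line 61). Under `jax.jit`, `cjac` is a tracer, so `qml.math.allclose` returns a traced boolean that Python cannot convert to `True` or `False`. The sketch below reproduces the same failure in plain JAX, without PennyLane; the helper `contract_if_identity` is hypothetical and only mimics the shape of that check.

```python
import jax
import jax.numpy as jnp


def contract_if_identity(mt, cjac):
    # Hypothetical helper mimicking the failing check: a Python `if` on the
    # result of jnp.allclose, which is a traced boolean under jax.jit.
    if jnp.allclose(cjac, jnp.eye(cjac.shape[0])):
        return mt
    return jnp.tensordot(mt, cjac, axes=[[-1], [0]])


mt = jnp.array([[1.0, 0.2], [0.2, 2.0]])
cjac = jnp.eye(2)

# Eager call: cjac holds concrete values, so the `if` works.
eager_result = contract_if_identity(mt, cjac)

# Jitted call: cjac is a tracer, so converting the comparison to bool fails.
try:
    jax.jit(contract_if_identity)(mt, cjac)
    jit_failed = False
except jax.errors.TracerBoolConversionError:
    jit_failed = True

print(jit_failed)  # True
```

The eager call succeeds and the jitted call raises the same `TracerBoolConversionError` as in the traceback, which is why the failure only appears when `use_jax == True`.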

qml.about()

Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/joey/miniconda3/envs/lanl/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info: macOS-14.4.1-arm64-arm-64bit
Python version: 3.12.6
Numpy version: 1.26.4
Scipy version: 1.12.0

Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.8.1)
- nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
- oqc.cloud (PennyLane-Catalyst-0.8.1)
- softwareq.qpp (PennyLane-Catalyst-0.8.1)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)

Hi @joeybarreto, sorry about the issue with using the word “call”. We’ve updated our settings so that you can use it now.

Let me check and get back to you. The issue doesn’t look related to the other post you linked.


Hi @joeybarreto ,

I wanted to let you know that I can replicate your error and this looks like a bug to me. We’re checking with the team and will keep you posted. Thanks for asking your question here and making us aware of this issue!

Hi @joeybarreto

Thanks for reporting this incompatibility.
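It’s easy to fix: when contracting the metric tensor with the classical Jacobian, we perform a conditional check (whether the classical Jacobian is the identity) that depends on the value of a traced array, which is not JIT-compatible. We can simply skip this check when JIT-compiling.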
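One way to express "skip the check when JIT-compiling" is to take the identity shortcut only when the classical Jacobian holds concrete values, and otherwise always perform the full contraction, which gives the same result when the Jacobian is the identity. This is a plain-JAX sketch, not the code in the fix PR; the helper name `contract_metric_tensor` is hypothetical, and detecting tracing with `isinstance(..., jax.core.Tracer)` is an assumption about how one might implement the check.

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer


def contract_metric_tensor(mt, cjac):
    """Hypothetical helper: return cjac^T @ mt @ cjac, skipping the work
    when cjac is concretely known to be the identity."""
    is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]
    # Shapes are static under jit, but values are not: only inspect the
    # values of cjac when it is not a tracer.
    if is_square and not isinstance(cjac, Tracer):
        if jnp.allclose(cjac, jnp.eye(cjac.shape[0])):
            return mt
    # Full contraction; equals mt when cjac is the identity.
    mt_cjac = jnp.tensordot(mt, cjac, axes=[[-1], [0]])
    return jnp.tensordot(cjac, mt_cjac, axes=[[0], [0]])


mt = jnp.array([[1.0, 0.2], [0.2, 2.0]])
cjac = jnp.eye(2)

eager = contract_metric_tensor(mt, cjac)
jitted = jax.jit(contract_metric_tensor)(mt, cjac)
print(jnp.allclose(eager, jitted))  # True
```

The trade-off is that the jitted path always pays for two tensor contractions, even when the Jacobian happens to be the identity, in exchange for never branching on a traced value.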
Here is the fix PR.
Let me know whether this fixes all the issues you’re seeing!
Thanks 🙂

Happy jitting!
