Error using JaxOpt in variational quantum circuits

Hi! I am doing a machine learning project which involves the following quantum circuit:

dev1 = qml.device("default.mixed", wires=2)
@qml.qnode(dev1)
def circuit(params_SU):
    qml.QubitDensityMatrix(identity_matrix/2, 0)
    qml.QubitDensityMatrix(rho, 1)
    qml.SpecialUnitary(params_SU, [0, 1])
    return qml.density_matrix([0])

The cost function is:

#define cost function
expected_result = rho / 2 + identity_matrix / 4
def cost(params_SU):
    trace_B = circuit(params_SU)
    cost_value = np.abs(trace_B - expected_result)
    return float(np.sum(cost_value))

And here is the code for the training:

import jaxopt

#initialise the optimizer
stepsize = 0.001
opt = jaxopt.GradientDescent(cost, stepsize=stepsize, acceleration = True)

# set the number of steps
steps = 100

# set the initial parameter values
opt_state = opt.init_state(params_SU)


cost_list = []
iteration_num = 1
cost_value = cost(params_SU)
cost_recording_step = 10

while True:
    # update the circuit parameters
    params_SU, opt_state= opt.update(params_SU, opt_state)

    if (iteration_num) % cost_recording_step == 0:
        cost_value = cost(params_SU)
        cost_list.append(cost_value)

    if (iteration_num) % 10 == 0:
        print("Cost after step {:5d}: {: .7f}".format(iteration_num, cost(params_SU)))
    
    if iteration_num >= steps:
        break

    iteration_num += 1

However, I kept getting this error that I had no idea to resolve:

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[7], line 22
     18 cost_threshold = 0.1
     20 while True:
     21     # update the circuit parameters
---> 22     params_SU, opt_state= opt.update(params_SU, opt_state)
     24     if (iteration_num) % cost_recording_step == 0:
     25         cost_value = cost(params_SU)

    [... skipping hidden 12 frame]

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\jaxopt\_src\gradient_descent.py:97, in GradientDescent.update(self, params, state, *args, **kwargs)
     82 def update(self,
     83            params: Any,
     84            state: NamedTuple,
     85            *args,
     86            **kwargs) -> base.OptStep:
     87   """Performs one iteration of gradient descent.
     88 
     89   Args:
   (...)
     95     (params, state)
     96   """
---> 97   return super().update(params, state, None, *args, **kwargs)

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\jaxopt\_src\proximal_gradient.py:305, in ProximalGradient.update(self, params, state, hyperparams_prox, *args, **kwargs)
    293 """Performs one iteration of proximal gradient.
    294 
    295 Args:
   (...)
    302   (params, state)
    303 """
    304 f = self._update_accel if self.acceleration else self._update
--> 305 return f(params, state, hyperparams_prox, args, kwargs)

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\jaxopt\_src\proximal_gradient.py:264, in ProximalGradient._update_accel(self, x, state, hyperparams_prox, args, kwargs)
    262 t = state.t
    263 stepsize = state.stepsize
--> 264 (y_fun_val, aux), y_fun_grad = self._value_and_grad_with_aux(y, *args,
    265                                                              **kwargs)
    266 next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad,
    267                                    stepsize, hyperparams_prox, args, kwargs)
    268 next_t = 0.5 * (1 + jnp.sqrt(1 + 4 * t ** 2))

    [... skipping hidden 8 frame]

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\jaxopt\_src\base.py:71, in _add_aux_to_fun.<locals>.fun_with_aux(*a, **kw)
     70 def fun_with_aux(*a, **kw):
---> 71   return fun(*a, **kw), None

Cell In[6], line 6
      4 def cost(params_SU):
      5     #trace_A = q_obj.ptrace(0)
----> 6     trace_B = circuit(params_SU)
      7     trace_B = np.abs(trace_B)
      8     #cost_value = np.abs(np.array(trace_A + trace_B - 2 * expected_result))
      9     #cost_value = np.abs(np.array(trace_B - expected_result))

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\qnode.py:1027, in QNode.__call__(self, *args, **kwargs)
   1022         full_transform_program._set_all_argnums(
   1023             self, args, kwargs, argnums
   1024         )  # pylint: disable=protected-access
   1026 # pylint: disable=unexpected-keyword-arg
-> 1027 res = qml.execute(
   1028     (self._tape,),
   1029     device=self.device,
   1030     gradient_fn=self.gradient_fn,
   1031     interface=self.interface,
   1032     transform_program=full_transform_program,
   1033     config=config,
   1034     gradient_kwargs=self.gradient_kwargs,
   1035     override_shots=override_shots,
   1036     **self.execute_kwargs,
   1037 )
   1039 res = res[0]
   1041 # convert result to the interface in case the qfunc has no parameters

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\interfaces\execution.py:616, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
    614 # Exiting early if we do not need to deal with an interface boundary
    615 if no_interface_boundary_required:
--> 616     results = inner_execute(tapes)
    617     return post_processing(results)
    619 _grad_on_execution = False

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\interfaces\execution.py:249, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    247 if numpy_only:
    248     tapes = tuple(qml.transforms.convert_to_numpy_parameters(t) for t in tapes)
--> 249 return cached_device_execution(tapes)

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\interfaces\execution.py:371, in cache_execute.<locals>.wrapper(tapes, **kwargs)
    366         return (res, []) if return_tuple else res
    368 else:
    369     # execute all unique tapes that do not exist in the cache
    370     # convert to list as new device interface returns a tuple
--> 371     res = list(fn(tuple(execution_tapes.values()), **kwargs))
    373 final_res = []
    375 for i, tape in enumerate(tapes):

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\_qubit_device.py:460, in QubitDevice.batch_execute(self, circuits)
    455 for circuit in circuits:
    456     # we need to reset the device here, else it will
    457     # not start the next computation in the zero state
    458     self.reset()
--> 460     res = self.execute(circuit)
    461     results.append(res)
    463 if self.tracker.active:

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\devices\default_mixed.py:685, in DefaultMixed.execute(self, circuit, **kwargs)
    683         wires_list.append(m.wires)
    684     self.measured_wires = qml.wires.Wires.all_wires(wires_list)
--> 685 return super().execute(circuit, **kwargs)

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\_qubit_device.py:279, in QubitDevice.execute(self, circuit, **kwargs)
    276 self.check_validity(circuit.operations, circuit.observables)
    278 # apply all circuit operations
--> 279 self.apply(circuit.operations, rotations=self._get_diagonalizing_gates(circuit), **kwargs)
    281 # generate computational basis samples
    282 if self.shots is not None or circuit.is_sampled:

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\devices\default_mixed.py:699, in DefaultMixed.apply(self, operations, rotations, **kwargs)
    693         raise DeviceError(
    694             f"Operation {operation.name} cannot be used after other Operations have already been applied "
    695             f"on a {self.short_name} device."
    696         )
    698 for operation in operations:
--> 699     self._apply_operation(operation)
    701 # store the pre-rotated state
    702 self._pre_rotated_state = self._state

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\devices\default_mixed.py:604, in DefaultMixed._apply_operation(self, operation)
    601     return
    603 if isinstance(operation, QubitDensityMatrix):
--> 604     self._apply_density_matrix(operation.parameters[0], wires)
    605     return
    607 if isinstance(operation, Snapshot):

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\pennylane\devices\default_mixed.py:540, in DefaultMixed._apply_density_matrix(self, state, device_wires)
    537 if dm_dim != state.shape[0]:
    538     raise ValueError("Density matrix must be of length (2**wires, 2**wires)")
--> 540 if not qnp.allclose(
    541     qnp.trace(qnp.reshape(state, (state_dim, state_dim))), 1.0, atol=tolerance
    542 ):
    543     raise ValueError("Trace of density matrix is not equal one.")
    545 if len(device_wires) == self.num_wires and sorted(device_wires.labels) == list(
    546     device_wires.labels
    547 ):
    548     # Initialize the entire wires with the state

    [... skipping hidden 1 frame]

File c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\core.py:1510, in concretization_function_error.<locals>.error(self, arg)
   1509 def error(self, arg):
-> 1510   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function update at c:\Users\User\AppData\Local\Programs\Python\Python311\Lib\site-packages\jaxopt\_src\gradient_descent.py:82 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:c128[] = convert_element_type[new_dtype=complex128 weak_type=False] b
    from line C:\Users\User\AppData\Local\Temp\ipykernel_25644\952782401.py:3:27 (circuit)

  operation a:c128[2,2] = div b c
    from line C:\Users\User\AppData\Local\Temp\ipykernel_25644\952782401.py:3:27 (circuit)

  operation a:bool[] = pjit[
  name=allclose
  jaxpr={ lambda ; b:c128[] c:f64[] d:f64[] e:f64[]. let
      f:bool[] = pjit[
        name=isclose
        jaxpr={ lambda ; g:c128[] h:f64[] i:f64[] j:f64[]. let
            k:c128[] = convert_element_type[
              new_dtype=complex128
              weak_type=False
            ] h
            l:f64[] = convert_element_type[new_dtype=float64 weak_type=False] i
            m:f64[] = convert_element_type[new_dtype=float64 weak_type=False] j
            n:c128[] = sub g k
            o:f64[] = abs n
            p:f64[] = abs k
            q:f64[] = mul l p
            r:f64[] = add m q
            s:bool[] = le o r
            t:bool[] = pjit[
              name=isinf
              jaxpr={ lambda ; u:c128[]. let
                  v:f64[] = real u
                  w:f64[] = imag u
                  x:f64[] = abs v
                  y:bool[] = eq x inf
                  z:f64[] = abs w
                  ba:bool[] = eq z inf
                  bb:bool[] = or y ba
                in (bb,) }
            ] g
            bc:bool[] = pjit[
              name=isinf
              jaxpr={ lambda ; u:c128[]. let
                  v:f64[] = real u
                  w:f64[] = imag u
                  x:f64[] = abs v
                  y:bool[] = eq x inf
                  z:f64[] = abs w
                  ba:bool[] = eq z inf
                  bb:bool[] = or y ba
                in (bb,) }
            ] k
            bd:bool[] = or t bc
            be:bool[] = and t bc
            bf:bool[] = not bd
            bg:bool[] = and s bf
            bh:bool[] = eq g k
            bi:bool[] = and be bh
            bj:bool[] = or bg bi
            bk:bool[] = ne g g
            bl:bool[] = ne k k
            bm:bool[] = or bk bl
            bn:bool[] = not bm
            bo:bool[] = and bj bn
          in (bo,) }
      ] b c e d
      bp:bool[] = reduce_and[axes=()] f
    in (bp,) }
] bq br bs bt
    from line C:\Users\User\AppData\Local\Temp\ipykernel_25644\192657765.py:6:14 (cost)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

It would be very kind if anyone can give offer some suggestions and resolutions. I have been stuck on this for quite a while :frowning: Thank you so much!!! Here is the output of qml.about().

Name: PennyLane
Version: 0.33.1
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: c:\Users\86986\AppData\Local\Programs\Python\Python311\Lib\site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning

Platform info:           Windows-10-10.0.22631-SP0
Python version:          3.11.4
Numpy version:           1.26.4
Scipy version:           1.11.2
Installed devices:
- default.gaussian (PennyLane-0.33.1)
- default.mixed (PennyLane-0.33.1)
- default.qubit (PennyLane-0.33.1)
- default.qubit.autograd (PennyLane-0.33.1)
- default.qubit.jax (PennyLane-0.33.1)
- default.qubit.legacy (PennyLane-0.33.1)
- default.qubit.tf (PennyLane-0.33.1)
- default.qubit.torch (PennyLane-0.33.1)
- default.qutrit (PennyLane-0.33.1)
- null.qubit (PennyLane-0.33.1)
- lightning.qubit (PennyLane-Lightning-0.33.1)

Hey @Ji_Ching,

I’ll have to look into this and it might take a couple days at least. In the meantime, does updating your PennyLane version do anything? Currently we’re at v0.35.1. You can update via pip install --upgrade pennylane.