I’m trying to run batched circuit executions with PennyLane 0.22.1 and JAX 0.3.6. Here’s the original (non-JAX) version, which runs fine:
import pennylane as qml
from pennylane import numpy as qnp
import numpy as np
# Baseline (no JAX): build a one-wire "default.qubit" device and evaluate the
# circuit once per parameter value with a plain Python loop.
dev = qml.device("default.qubit", wires=1, shots=None)


@qml.qnode(dev)
def circuit(x):
    """Apply a diagonal unitary parameterized by ``x`` to wire 0 and return the PauliX expectation value."""
    # The diagonal passed in is [exp(-1j * x / 2), exp(+1j * x / 2)].
    qml.DiagonalQubitUnitary(qnp.exp(1j * x * np.array([-1, 1]) / 2), wires=0)
    return qml.expval(qml.PauliX(0))


batch_params = [0, 1]
# One independent QNode call per entry of batch_params.
print([circuit(x) for x in batch_params])
which outputs [tensor(0., requires_grad=True), tensor(0., requires_grad=True)]
as expected. Then I convert this to a circuit compatible with jax:
import jax
import jax.numpy as jnp
# JAX variant: same device constructor arguments, but the QNode is created with
# interface="jax" and the diagonal is built with jax.numpy.
dev = qml.device("default.qubit", wires=1, shots=None)


@qml.qnode(dev, interface="jax")
def circuit(x):
    """Apply a diagonal unitary built from ``x`` with jax.numpy to wire 0 and return the PauliX expectation value."""
    # The diagonal passed in is [exp(+1j * x / 2), exp(-1j * x / 2)].
    # NOTE(review): the sign vector here is [1, -1], whereas the non-JAX snippet
    # earlier in this post uses [-1, 1] -- confirm whether the swap is intended.
    qml.DiagonalQubitUnitary(jnp.exp(1j * x * jnp.array([1, -1]) / 2), wires=0)
    return qml.expval(qml.PauliX(0))


# Batch execute with jax
vcircuit = jax.vmap(circuit)
batch_params = jnp.asarray([0, 1])
# This call raises the ConcretizationTypeError shown in the traceback below.
# The traceback ends in DiagonalQubitUnitary.compute_eigvals, where the result of
# qml.math.allclose(...) is used as an `if` condition; JAX reports that the
# problem arose with the `bool` function on a traced value.
res = vcircuit(batch_params)
This throws the error below. I get the feeling there’s an issue with using the unitary validator on a jax tracer type. But maybe my circuit is written poorly. Any thoughts on a workaround or fix?
Thanks
Traceback:
ConcretizationTypeError Traceback (most recent call last)
Input In [4], in <module>
13 vcircuit = jax.vmap(circuit)
14 batch_params = jnp.asarray([0, 1])
---> 15 res = vcircuit(batch_params)
[... skipping hidden 3 frame]
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/qnode.py:578, in QNode.__call__(self, *args, **kwargs)
571 using_custom_cache = (
572 hasattr(cache, "__getitem__")
573 and hasattr(cache, "__setitem__")
574 and hasattr(cache, "__delitem__")
575 )
576 self._tape_cached = using_custom_cache and self.tape.hash in cache
--> 578 res = qml.execute(
579 [self.tape],
580 device=self.device,
581 gradient_fn=self.gradient_fn,
582 interface=self.interface,
583 gradient_kwargs=self.gradient_kwargs,
584 override_shots=override_shots,
585 **self.execute_kwargs,
586 )
588 if autograd.isinstance(res, (tuple, list)) and len(res) == 1:
589 # If a device batch transform was applied, we need to 'unpack'
590 # the returned tuple/list to a float.
(...)
597 # TODO: find a more explicit way of determining that a batch transform
598 # was applied.
600 res = res[0]
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/interfaces/batch/__init__.py:342, in execute(tapes, device, gradient_fn, interface, mode, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
338 return batch_fn(res)
340 if gradient_fn == "backprop" or interface is None:
341 return batch_fn(
--> 342 cache_execute(batch_execute, cache, return_tuple=False, expand_fn=expand_fn)(tapes)
343 )
345 # the default execution function is batch_execute
346 execute_fn = cache_execute(batch_execute, cache, expand_fn=expand_fn)
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/interfaces/batch/__init__.py:173, in cache_execute.<locals>.wrapper(tapes, **kwargs)
169 return (res, []) if return_tuple else res
171 else:
172 # execute all unique tapes that do not exist in the cache
--> 173 res = fn(execution_tapes.values(), **kwargs)
175 final_res = []
177 for i, tape in enumerate(tapes):
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/interfaces/batch/__init__.py:125, in cache_execute.<locals>.fn(tapes, **kwargs)
123 def fn(tapes, **kwargs): # pylint: disable=function-redefined
124 tapes = [expand_fn(tape) for tape in tapes]
--> 125 return original_fn(tapes, **kwargs)
File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
72 @wraps(func)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/_qubit_device.py:289, in QubitDevice.batch_execute(self, circuits)
284 for circuit in circuits:
285 # we need to reset the device here, else it will
286 # not start the next computation in the zero state
287 self.reset()
--> 289 res = self.execute(circuit)
290 results.append(res)
292 if self.tracker.active:
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/_qubit_device.py:201, in QubitDevice.execute(self, circuit, **kwargs)
198 self.check_validity(circuit.operations, circuit.observables)
200 # apply all circuit operations
--> 201 self.apply(circuit.operations, rotations=circuit.diagonalizing_gates, **kwargs)
203 # generate computational basis samples
204 if self.shots is not None or circuit.is_sampled:
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/devices/default_qubit.py:226, in DefaultQubit.apply(self, operations, rotations, **kwargs)
224 self._debugger.snapshots[len(self._debugger.snapshots)] = state_vector
225 else:
--> 226 self._state = self._apply_operation(self._state, operation)
228 # store the pre-rotated state
229 self._pre_rotated_state = self._state
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/devices/default_qubit.py:251, in DefaultQubit._apply_operation(self, state, operation)
248 axes = self.wires.indices(wires)
249 return self._apply_ops[operation.base_name](state, axes, inverse=operation.inverse)
--> 251 matrix = self._get_unitary_matrix(operation)
253 if operation in diagonal_in_z_basis:
254 return self._apply_diagonal_unitary(state, matrix, wires)
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/devices/default_qubit.py:556, in DefaultQubit._get_unitary_matrix(self, unitary)
545 """Return the matrix representing a unitary operation.
546
547 Args:
(...)
553 a 1D array representing the matrix diagonal.
554 """
555 if unitary in diagonal_in_z_basis:
--> 556 return unitary.get_eigvals()
558 return unitary.get_matrix()
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/operation.py:1380, in Operation.get_eigvals(self)
1379 def get_eigvals(self):
-> 1380 op_eigvals = super().get_eigvals()
1382 if self.inverse:
1383 return qml.math.conj(op_eigvals)
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/operation.py:753, in Operator.get_eigvals(self)
728 r"""Eigenvalues of the operator in the computational basis (static method).
729
730 If :attr:`diagonalizing_gates` are specified and implement a unitary :math:`U`, the operator
(...)
749 tensor_like: eigenvalues
750 """
752 try:
--> 753 return self.compute_eigvals(*self.parameters, **self.hyperparameters)
754 except EigvalsUndefinedError:
755 # By default, compute the eigenvalues from the matrix representation.
756 # This will raise a NotImplementedError if the matrix is undefined.
757 try:
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/ops/qubit/matrix_ops.py:406, in DiagonalQubitUnitary.compute_eigvals(D)
380 r"""Eigenvalues of the operator in the computational basis (static method).
381
382 If :attr:`diagonalizing_gates` are specified and implement a unitary :math:`U`,
(...)
402 tensor([ 1, -1])
403 """
404 D = qml.math.asarray(D)
--> 406 if not qml.math.allclose(D * qml.math.conj(D), qml.math.ones_like(D)):
407 raise ValueError("Operator must be unitary.")
409 return D
[... skipping hidden 1 frame]
File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/jax/core.py:1123, in concretization_function_error.<locals>.error(self, arg)
1122 def error(self, arg):
-> 1123 raise ConcretizationTypeError(arg, fname_context)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/0)> with
val = DeviceArray([ True, True], dtype=bool)
batch_dim = 0
The problem arose with the `bool` function.
This Tracer was created on line /home/ares/projects/envs/xanadu-3.8/lib/python3.8/site-packages/autoray/autoray.py:84 (do)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError