Because there is no native CU3 gate, I implemented it with the qml.ctrl function. However, I was surprised to find that it upcasts values to 64-bit precision:
import jax
import jax.numpy as jnp
import pennylane as qml
# Keep JAX's 64-bit dtypes disabled and run everything on the CPU backend.
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_platform_name", "cpu")

# Circuit dimensions: how many layers are applied and how many wires exist.
n_layers = 1
n_qubits = 2

# Simulator device shared by every QNode defined below.
dev = qml.device("default.qubit", wires=n_qubits)
def CU3(phi, theta, delta, wires):
    """Apply a controlled U3 rotation built from ``qml.ctrl``.

    Args:
        phi: First angle, forwarded positionally to ``qml.U3``.
        theta: Second angle, forwarded positionally to ``qml.U3``.
        delta: Third angle, forwarded positionally to ``qml.U3``.
        wires: Indexable pair; ``wires[0]`` is the control wire and
            ``wires[1]`` is the target wire.

    NOTE(review): the angles are passed in the order (phi, theta, delta);
    PennyLane documents ``qml.U3`` as ``U3(theta, phi, delta, ...)`` --
    confirm this ordering is intended.
    """
    # Wrap U3 so that it only acts when the control wire is set.
    controlled_u3 = qml.ctrl(qml.U3, control=wires[0])
    controlled_u3(phi, theta, delta, wires=wires[1])
@qml.qnode(dev)
def u3_circuit(phis_u3, thetas_u3, deltas_u3, phis_cu3, thetas_cu3, deltas_cu3):
    """Apply one U3 gate per wire per layer and return the outcome probabilities.

    Args:
        phis_u3, thetas_u3, deltas_u3: Angle containers indexed as
            ``[layer, wire]``; one entry is read for every layer/wire pair.
        phis_cu3, thetas_cu3, deltas_cu3: Accepted so the signature matches
            the CU3 circuit, but not used here.

    Returns:
        The ``qml.probs`` measurement over all ``n_qubits`` wires.

    NOTE(review): the angles are handed to ``qml.U3`` positionally in the
    order (phi, theta, delta); PennyLane documents the gate as
    ``U3(theta, phi, delta, ...)`` -- confirm this ordering is intended.
    """
    all_wires = list(range(n_qubits))
    for depth in range(n_layers):
        for wire in all_wires:
            angles = (
                phis_u3[depth, wire],
                thetas_u3[depth, wire],
                deltas_u3[depth, wire],
            )
            qml.U3(*angles, wires=wire)
    return qml.probs(wires=all_wires)
# Evaluate the plain-U3 circuit with every angle set to 1 and show the result.
print('U3 result')
u3_inputs = jnp.ones((6, n_layers, n_qubits))  # one (layers, qubits) array per argument
u3_result = u3_circuit(*u3_inputs)
u3_result = u3_result.block_until_ready()  # wait for the computation to finish
print(u3_result)
@qml.qnode(dev)
def cu3_circuit(phis_u3, thetas_u3, deltas_u3, phis_cu3, thetas_cu3, deltas_cu3):
    """Apply a ring of controlled-U3 gates per layer and return the probabilities.

    In every layer, each wire ``i`` controls a ``CU3`` acting on wire
    ``(i + 1) % n_qubits``, so the last wire wraps around to control wire 0.

    Args:
        phis_u3, thetas_u3, deltas_u3: Accepted so the signature matches
            the U3 circuit, but not used here.
        phis_cu3, thetas_cu3, deltas_cu3: Angle containers indexed as
            ``[layer, control_wire]``.

    Returns:
        The ``qml.probs`` measurement over all ``n_qubits`` wires.
    """
    for depth in range(n_layers):
        for control in range(n_qubits):
            target = (control + 1) % n_qubits  # neighbouring wire, wrapping around
            CU3(
                phis_cu3[depth, control],
                thetas_cu3[depth, control],
                deltas_cu3[depth, control],
                wires=[control, target],
            )
    measured_wires = list(range(n_qubits))
    return qml.probs(wires=measured_wires)
# Evaluate the controlled-U3 circuit with every angle set to 1 and show the result.
print('CU3 result')
cu3_inputs = jnp.ones((6, n_layers, n_qubits))  # one (layers, qubits) array per argument
cu3_result = cu3_circuit(*cu3_inputs)
cu3_result = cu3_result.block_until_ready()  # wait for the computation to finish
print(cu3_result)
U3 result
[0.5931328 0.17701836 0.17701836 0.05283049]
CU3 result
[1. 0. 0. 0.]
/home/skylar/quantum/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
[repeats a few more times]
I believe this warning is safe to ignore, but I’m curious why qml.ctrl upcasts to 64-bit precision when qml.U3 does not.