Why does qml.ctrl upcast to 64-bit precision?

Because there is no native CU3 gate, I implemented one with the qml.ctrl function. However, I was surprised to find that it upcasts values to 64-bit precision:

import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", False)

n_qubits = 2
n_layers = 1
dev = qml.device('default.qubit', wires=n_qubits)

def CU3(phi, theta, delta, wires):
    qml.ctrl(qml.U3, control=(wires[0]))(phi, theta, delta, wires=wires[1])

@qml.qnode(dev)
def u3_circuit(phis_u3, thetas_u3, deltas_u3, phis_cu3, thetas_cu3, deltas_cu3):
    for layer in range(n_layers):
        for i in range(n_qubits):
            qml.U3(phis_u3[layer, i],
                   thetas_u3[layer, i],
                   deltas_u3[layer, i],
                   wires=i)
    return qml.probs(wires=list(range(n_qubits)))

print('U3 result')
u3_result = u3_circuit(*jnp.ones((6, n_layers, n_qubits))).block_until_ready()
print(u3_result)

@qml.qnode(dev)
def cu3_circuit(phis_u3, thetas_u3, deltas_u3, phis_cu3, thetas_cu3, deltas_cu3):
    for layer in range(n_layers):
        for i in range(n_qubits):
            CU3(phis_cu3[layer, i],
                thetas_cu3[layer, i],
                deltas_cu3[layer, i],
                wires=[i, (i+1) % n_qubits])
            
    return qml.probs(wires=list(range(n_qubits)))

print('CU3 result')
cu3_result = cu3_circuit(*jnp.ones((6, n_layers, n_qubits))).block_until_ready()
print(cu3_result)

Output:

U3 result
[0.5931328  0.17701836 0.17701836 0.05283049]
CU3 result
[1. 0. 0. 0.]

/home/skylar/quantum/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)

[repeats a few more times]

I believe this warning is safe to ignore, but I’m curious why qml.ctrl upcasts to 64-bit precision when qml.U3 does not.
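The claim that the warning is harmless can be checked in isolation. The sketch below uses plain JAX (no PennyLane) and assumes only the setting from the question: with jax_enable_x64 disabled, an explicit astype to complex128 triggers the same UserWarning and falls back to complex64, so the array that comes out still has 32-bit components. It does not show where PennyLane makes that request.

```python
import warnings

import jax
import jax.numpy as jnp

# Same setting as in the question: 32-bit mode, so complex128 is unavailable.
jax.config.update("jax_enable_x64", False)

state = jnp.ones(4, dtype=jnp.complex64)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Explicitly request complex128; JAX truncates the request to complex64.
    upcast = state.astype(jnp.complex128)

print(upcast.dtype)  # complex64
print(len(caught), "warning(s) recorded")
```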

Hi @schance995,

It’s not a precision issue. What is happening is that the control value defaults to 1. Since both of your qubits start in the $\vert 0 \rangle$ state, none of the controlled operations are actually applied (the control qubit never has the required value). The circuit therefore does nothing, so the probability is 1 for the state $\vert 00 \rangle$ and 0 for all other states.
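The effect of the default control value can be reproduced by multiplying the matrices out by hand. This is a NumPy sketch, not a PennyLane run: it assumes PennyLane’s documented U3 matrix, treats wire 0 as the leftmost bit, and the helpers u3, controlled_2q and cu3_circuit_probs are illustrative names written for this example.

```python
import numpy as np


def u3(theta, phi, delta):
    """U3 matrix in PennyLane's documented convention, U3(theta, phi, delta)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -np.exp(1j * delta) * s],
                     [np.exp(1j * phi) * s, np.exp(1j * (phi + delta)) * c]])


def controlled_2q(u, control, control_value=1):
    """Two-qubit controlled-u; the target is the other wire (wire 0 is the leftmost bit)."""
    proj = {0: np.diag([1.0, 0.0]), 1: np.diag([0.0, 1.0])}
    on, off = proj[control_value], proj[1 - control_value]
    eye = np.eye(2)
    if control == 0:
        return np.kron(off, eye) + np.kron(on, u)
    return np.kron(eye, off) + np.kron(u, on)


def cu3_circuit_probs(control_value):
    """Mirror of the question's cu3_circuit with every angle set to 1."""
    u = u3(1.0, 1.0, 1.0)
    state = np.array([1, 0, 0, 0], dtype=complex)  # |00>
    for i in range(2):  # CU3 on wires [0, 1], then on wires [1, 0]
        state = controlled_2q(u, control=i, control_value=control_value) @ state
    return np.abs(state) ** 2


print(cu3_circuit_probs(1))  # [1. 0. 0. 0.]
```

With control_value=1 the state never leaves $\vert 00 \rangle$, which matches the [1. 0. 0. 0.] printed in the question.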

The code below is a modified version of your code in which the control value is 0, which gives the expected output. Note that the probabilities in the two cases won’t be exactly the same, because the second controlled operation acts on a state that is no longer $\vert 00 \rangle$, so the two circuits are not equivalent.

import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True) # change this to True

n_qubits = 2
n_layers = 1
dev = qml.device('default.qubit', wires=n_qubits)

def CU3(phi, theta, delta, wires):
    qml.ctrl(qml.U3, control=(wires[0]), control_values=0)(phi, theta, delta, wires=wires[1]) # Add the control value

@qml.qnode(dev)
def u3_circuit(phis_u3, thetas_u3, deltas_u3, phis_cu3, thetas_cu3, deltas_cu3):
    for layer in range(n_layers):
        for i in range(n_qubits):
            qml.U3(phis_u3[layer, i],
                   thetas_u3[layer, i],
                   deltas_u3[layer, i],
                   wires=i)
    return qml.probs(wires=list(range(n_qubits)))

qml.draw_mpl(u3_circuit)(*jnp.ones((6, n_layers, n_qubits))) # draw the circuit
print('U3 result')
u3_result = u3_circuit(*jnp.ones((6, n_layers, n_qubits))).block_until_ready()
print(u3_result)

@qml.qnode(dev)
def cu3_circuit(phis_u3, thetas_u3, deltas_u3, phis_cu3, thetas_cu3, deltas_cu3):
    for layer in range(n_layers):
        for i in range(n_qubits):
            CU3(phis_cu3[layer, i],
                thetas_cu3[layer, i],
                deltas_cu3[layer, i],
                wires=[i, (i+1) % n_qubits])
            
    return qml.probs(wires=list(range(n_qubits)))

qml.draw_mpl(cu3_circuit)(*jnp.ones((6, n_layers, n_qubits))) # draw the circuit
print('CU3 result')
cu3_result = cu3_circuit(*jnp.ones((6, n_layers, n_qubits))).block_until_ready()
print(cu3_result)
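Why the two outputs still differ with control_values=0 can be seen by working both circuits out by hand. This is again a NumPy sketch assuming PennyLane’s U3 matrix convention and wire 0 as the leftmost bit (u3 and controlled_2q are illustrative helpers, not PennyLane functions). The U3 circuit prepares a product state, while the second CU3 acts on a state that is no longer $\vert 00 \rangle$:

```python
import numpy as np


def u3(theta, phi, delta):
    """U3 matrix in PennyLane's documented convention, U3(theta, phi, delta)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -np.exp(1j * delta) * s],
                     [np.exp(1j * phi) * s, np.exp(1j * (phi + delta)) * c]])


def controlled_2q(u, control, control_value=1):
    """Two-qubit controlled-u; the target is the other wire (wire 0 is the leftmost bit)."""
    proj = {0: np.diag([1.0, 0.0]), 1: np.diag([0.0, 1.0])}
    on, off = proj[control_value], proj[1 - control_value]
    eye = np.eye(2)
    if control == 0:
        return np.kron(off, eye) + np.kron(on, u)
    return np.kron(eye, off) + np.kron(u, on)


u = u3(1.0, 1.0, 1.0)
ket00 = np.array([1, 0, 0, 0], dtype=complex)

# u3_circuit: U3 on each wire independently, giving a product state.
u3_probs = np.abs(np.kron(u, u) @ ket00) ** 2

# cu3_circuit with control_values=0: CU3 on wires [0, 1], then on wires [1, 0].
state = ket00
for control in (0, 1):
    state = controlled_2q(u, control=control, control_value=0) @ state
cu3_probs = np.abs(state) ** 2

print(np.round(u3_probs, 4))
print(np.round(cu3_probs, 4))
```

The first vector should reproduce the U3 result printed in the question (up to float32 rounding). In the second, the $\vert 00 \rangle$ and $\vert 10 \rangle$ probabilities agree with the first, but the $\vert 01 \rangle$ and $\vert 11 \rangle$ probabilities differ.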

Hi @CatalineAlbornoz, thanks for your reply. Actually, I chose the weights specifically to highlight the change in numeric precision from complex64 to complex128 when the controlled U3 gate is used. In my original example, the U3 gate is computed in complex64, so I expected the controlled U3 gate to be computed in complex64 as well, especially because no computation should be occurring. Instead, the controlled U3 casts to complex128. My concern is that something in the code for the qml.ctrl function causes an upcast to complex128, when I expected the precision of the original and controlled versions to be the same. I hope this clarifies my original question.
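To find which line requests complex128, one option is to promote the UserWarning to an exception so that Python prints a full traceback. The sketch below demonstrates the technique on build_gate_matrix, a hypothetical stand-in for library code (it is not part of PennyLane); locate_upcast is likewise an illustrative helper.

```python
import traceback
import warnings

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)


def build_gate_matrix(theta):
    """Hypothetical stand-in for library code that explicitly requests complex128."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    mat = jnp.array([[c, -s], [s, c]])
    return mat.astype(jnp.complex128)  # the request JAX cannot honor in 32-bit mode


def locate_upcast(fn, *args):
    """Run fn with UserWarning promoted to an error; return the traceback text, or None."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        try:
            fn(*args)
        except UserWarning:
            return traceback.format_exc()
    return None


report = locate_upcast(build_gate_matrix, jnp.float32(0.5))
print(report)
```

Applied to the question, something like locate_upcast(cu3_circuit, *jnp.ones((6, n_layers, n_qubits))) should print a traceback whose last frames outside jax/ point at the code that asks for complex128. This has not been run against a specific PennyLane version here.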

Ah, sorry, I now understand your question, @schance995.

It’s strange, because qml.ctrl doesn’t always change the numeric precision. For example, the following code runs without any warning:

import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update("jax_enable_x64", False) 

@qml.qnode(qml.device('default.qubit', wires=range(2)))
def circuit(phi, theta, omega):
    qml.ctrl(qml.Rot, control=(0), control_values=1)(phi, theta, omega, wires=1) 
    return qml.expval(qml.Z(0))

phi = theta = omega = jnp.array(1.2)
qml.draw_mpl(circuit)(phi, theta, omega); # draw the circuit
circuit(phi, theta, omega)

But if you add one extra layer of control, the warning does appear:

jax.config.update("jax_enable_x64", False) 

@qml.qnode(qml.device('default.qubit', wires=range(3)))
def circuit(phi, theta, omega):
    qml.ctrl(qml.ctrl(qml.Rot, control=(1), control_values=1), control=(0), control_values=1)(phi, theta, omega, wires=2) 
    return qml.expval(qml.Z(0))

phi = theta = omega = jnp.array(1.2)
qml.draw_mpl(circuit)(phi, theta, omega); # draw the circuit
circuit(phi, theta, omega)

Output:

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
Array(1., dtype=float32)
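Mathematically, wrapping qml.ctrl twice should give the same operator as one controlled operation with both control wires, so the extra complex128 request presumably comes from how the nested wrapper is assembled rather than from anything the matrix itself needs. A NumPy sketch of that equivalence, assuming PennyLane’s documented Rot matrix convention and wire 0 as the leftmost bit (rz, ry, rot and add_control are written here for illustration and are not PennyLane functions):

```python
import numpy as np


def rz(a):
    return np.diag([np.exp(-0.5j * a), np.exp(0.5j * a)])


def ry(a):
    c, s = np.cos(a / 2), np.sin(a / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


def rot(phi, theta, omega):
    """PennyLane's documented Rot convention: RZ(omega) RY(theta) RZ(phi)."""
    return rz(omega) @ ry(theta) @ rz(phi)


def add_control(u):
    """Prepend one control qubit (leftmost wire) that must be |1> for u to act."""
    n = u.shape[0]
    out = np.eye(2 * n, dtype=complex)
    out[n:, n:] = u
    return out


u = rot(1.2, 1.2, 1.2)

# ctrl(ctrl(Rot, control=1), control=0): add wire 1, then wire 0, as controls.
nested = add_control(add_control(u))

# A single operation with two control wires [0, 1]: Rot acts only on the |11> control block.
double = np.eye(8, dtype=complex)
double[6:, 6:] = u

print(np.allclose(nested, double))  # True
```

Whether passing both wires at once, e.g. qml.ctrl(qml.Rot, control=[0, 1]), also avoids the warning is untested here.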

If you can use qml.Rot instead of qml.U3 for now, that might solve your issue, but we’ll investigate to see what’s going on.
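One caveat if qml.Rot is swapped in for qml.U3: the two gates agree only up to a global phase, and once the gate is controlled that phase turns into a relative phase on the control qubit. A NumPy check under PennyLane’s documented matrix conventions (all helper names here are illustrative, not PennyLane API):

```python
import numpy as np


def rz(a):
    return np.diag([np.exp(-0.5j * a), np.exp(0.5j * a)])


def ry(a):
    c, s = np.cos(a / 2), np.sin(a / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


def rot(phi, theta, omega):
    """PennyLane's documented Rot convention: RZ(omega) RY(theta) RZ(phi)."""
    return rz(omega) @ ry(theta) @ rz(phi)


def u3(theta, phi, delta):
    """PennyLane's documented U3 convention, U3(theta, phi, delta)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -np.exp(1j * delta) * s],
                     [np.exp(1j * phi) * s, np.exp(1j * (phi + delta)) * c]])


def controlled(u):
    """Controlled-u with the control on the leftmost wire, active on |1>."""
    out = np.eye(4, dtype=complex)
    out[2:, 2:] = u
    return out


def phase_shift(a):
    return np.diag([1.0, np.exp(1j * a)])


theta, phi, delta = 0.3, 0.7, 1.1

# Uncontrolled, U3 equals Rot(delta, theta, phi) up to the phase e^{i(phi+delta)/2}.
print(np.allclose(u3(theta, phi, delta),
                  np.exp(0.5j * (phi + delta)) * rot(delta, theta, phi)))  # True

target = controlled(u3(theta, phi, delta))
naive = controlled(rot(delta, theta, phi))
# Restore the lost phase with a phase shift on the control qubit.
fixed = np.kron(phase_shift((phi + delta) / 2), np.eye(2)) @ naive

print(np.allclose(target, naive))  # False: the global phase became a relative phase
print(np.allclose(target, fixed))  # True
```

In circuit form this corresponds to qml.ctrl(qml.Rot, control=wires[0])(delta, theta, phi, wires=wires[1]) followed by qml.PhaseShift((phi + delta) / 2, wires=wires[0]); note that qml.U3 takes its angles in the order (theta, phi, delta). Whether this combination avoids the warning in a given PennyLane version has not been checked here.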

Thanks for flagging this.


I’ve opened this bug report to capture the issue.

We will mark this as non-urgent for the moment, but please let us know if it starts breaking things or becomes urgent for you, @schance995.

Thanks again for making us aware of this and helping to improve PennyLane!