Hi, I’m trying to run the QCBM tutorial (Quantum Circuit Born Machines | PennyLane Demos), swapping out the default.qubit device for the default.tensor device (How to simulate quantum circuits with tensor networks | PennyLane Demos):
dev = qml.device("default.tensor", method="mps", **kwargs_mps)
# Then from the QCBM tutorial the error appears here:
loss_1, px = qcbm.mmd_loss(weights) # squared MMD
loss_2 = mmd.k_expval(px, px) - 2 * mmd.k_expval(px, probs) + mmd.k_expval(probs, probs)
print(loss_1)
print(loss_2)
This gives the error:
NotImplementedError: Measurement process probs(wires=[]) currently not supported by default.tensor.
Is there a way around using .probs() for QCBMs?
@qml.qnode(dev)
def circuit(weights):
    qml.StronglyEntanglingLayers(weights=weights, ranges=[1] * n_layers, wires=range(n_qubits))
    return qml.probs()
Is there a plan to support this in default.tensor?
I’m using pennylane==0.38.
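For reference, the two quantities printed in the snippet above should agree: for a symmetric kernel matrix K, the squared MMD between distributions p and q expands as (p - q)ᵀK(p - q) = pᵀKp - 2pᵀKq + qᵀKq. The NumPy sketch below checks this on a toy 3-qubit sample space. It does not reproduce the demo's `mmd` object; it assumes a single Gaussian kernel over the integer labels of the basis states, and `gaussian_kernel_matrix` and `k_expval` are hypothetical helpers.

```python
import numpy as np


def gaussian_kernel_matrix(space, sigma=1.0):
    """Hypothetical helper: K[i, j] = exp(-(x_i - x_j)^2 / (2 sigma^2))."""
    sq_dists = (space[:, None] - space[None, :]) ** 2
    return np.exp(-sq_dists / (2 * sigma**2))


def k_expval(px, py, K):
    """Kernel expectation value E_{x~px, y~py}[k(x, y)] = px^T K py."""
    return px @ K @ py


n_qubits = 3
space = np.arange(2**n_qubits)  # integer labels of the 2**n basis states
K = gaussian_kernel_matrix(space, sigma=1.0)

rng = np.random.default_rng(0)
px = rng.random(2**n_qubits)
px /= px.sum()  # stand-in for the model (circuit) distribution
probs = rng.random(2**n_qubits)
probs /= probs.sum()  # stand-in for the target distribution

# Expanded form, mirroring loss_2 in the post
loss_expanded = k_expval(px, px, K) - 2 * k_expval(px, probs, K) + k_expval(probs, probs, K)

# Compact form: (px - probs)^T K (px - probs)
diff = px - probs
loss_compact = k_expval(diff, diff, K)

print(np.isclose(loss_expanded, loss_compact))  # True
```

Because K is symmetric, the cross terms pᵀKq and qᵀKp are equal, which is where the factor of 2 comes from.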
Hi @kezmcd1903, welcome to the Forum!
One option is to return qml.state() instead and then calculate the probabilities manually by multiplying each amplitude by its complex conjugate (p_i = ψ_i ψ_i* = |ψ_i|²).
I don’t think there’s an immediate plan to add support for qml.probs.
Let me know if my suggestion works for you!
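A minimal NumPy sketch of this suggestion, assuming `state` is the flat complex array of 2**n amplitudes returned by `qml.state()`, ordered |00⟩, |01⟩, |10⟩, |11⟩ for two qubits. The helper name `probs_from_state` is hypothetical, and the hard-coded state stands in for real circuit output.

```python
import numpy as np


def probs_from_state(state):
    """Probabilities from amplitudes: p_i = psi_i * conj(psi_i) = |psi_i|^2."""
    state = np.asarray(state)
    # The product is mathematically real; keep only the real part so the
    # result has a float dtype rather than a complex one.
    return np.real(state * np.conj(state))


# (|00> + i|11>) / sqrt(2): a 2-qubit state with a complex amplitude
psi = np.array([1, 0, 0, 1j]) / np.sqrt(2)
p = probs_from_state(psi)
print(np.allclose(p, [0.5, 0.0, 0.0, 0.5]))  # True
```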
Hi @CatalinaAlbornoz, thanks for your response. This works for the probability calculation.
If I understood correctly, we return qml.state() from the circuit and then take its squared modulus separately:
@qml.qnode(dev)
def circuit(weights):
    qml.StronglyEntanglingLayers(
        weights=weights, ranges=[1] * n_layers, wires=range(n_qubits)
    )
    return qml.state()

@jax.jit
def jit_circuit(weights):
    state = circuit(weights)
    return state * state.conj()  # |amplitude|^2 values, but the array dtype is still complex
The gradient optimization with JAX runs perfectly with default.qubit, but not with default.tensor:
Cell In[8], line 3
      1 @jax.jit
      2 def update_step(params, opt_state):
----> 3     (loss_val, qcbm_probs), grads = jax.value_and_grad(qcbm.mmd_loss, has_aux=True)(
      4         params
      5     )
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes. Expected complex128[512] (tangent type of complex128[512]) but got float64[512].
I have tried computing the probabilities with jnp.real(...) and jnp.abs(...)**2, and also tried computing them inside the circuit, all with no success.
What can be done?
Hi @Marioherreroglez , welcome to the Forum!
As the error suggests, you’ve got a dtype mismatch. The output of state*state.conj() is an array of complex numbers (whose imaginary parts are zero), so you’ll need to discard the imaginary part before returning it. You may have complex values elsewhere in your code too, so check your weights and parameters. If you’re still facing issues, could you please let us know whether you changed anything else in the code (with respect to the demo), and post your full error traceback?
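A quick NumPy check of the dtype point, with a random normalized state standing in for the circuit output (the variable names are illustrative): `state * state.conj()` keeps a complex dtype even though its imaginary parts are zero, while `np.real(...)` or `np.abs(...)**2` give float64. `jnp.real` and `jnp.abs` follow the same dtype rules in JAX.

```python
import numpy as np

rng = np.random.default_rng(0)
n_qubits = 3
# Random normalized complex state as a stand-in for qml.state() output
state = rng.normal(size=2**n_qubits) + 1j * rng.normal(size=2**n_qubits)
state /= np.linalg.norm(state)

product = state * state.conj()
print(product.dtype)  # complex128, even though the imaginary parts are (numerically) zero

# Two equivalent ways to get a real-valued probability vector
probs_real = np.real(product)
probs_abs = np.abs(state) ** 2
print(probs_real.dtype, probs_abs.dtype)  # float64 float64
print(np.allclose(probs_real, probs_abs))  # True
```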
Thanks!
Hi @CatalinaAlbornoz, thanks for the welcome.
I did test discarding the imaginary parts, but it didn’t seem very clean and it was hard to get JAX to accept it. I ended up adding a probs method to default_tensor.py, right under the existing state method:
def probs(self, measurementprocess: MeasurementProcess):  # pylint: disable=unused-argument
    """Returns the probability vector, calculated as the squared modulus of each amplitude in the state vector."""
    return (abs(self._quimb_circuit.psi.to_dense().ravel()) ** 2).real
(I also had to import ProbabilityMP alongside the other measurement imports and add a dispatch to self.probs in _get_measurement_function.)
With that, the notebook works with no changes other than switching default.qubit → default.tensor.
Convergence isn’t very fast, but it does work!
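The core of the patched method is the squared modulus of the flattened dense state. The NumPy sketch below mimics it, assuming (as the `.ravel()` call suggests) that the densified MPS comes back as a 2**n × 1 column vector; a random normalized vector stands in for `self._quimb_circuit.psi.to_dense()`, and `probs_from_dense_column` is a hypothetical name.

```python
import numpy as np


def probs_from_dense_column(dense):
    """Mimics the patched method: flatten the dense ket, then take |amplitude|^2."""
    return np.abs(np.asarray(dense).ravel()) ** 2


n_qubits = 4
rng = np.random.default_rng(1)
# Stand-in for the dense MPS: a normalized complex column vector of shape (2**n, 1)
ket = rng.normal(size=(2**n_qubits, 1)) + 1j * rng.normal(size=(2**n_qubits, 1))
ket /= np.linalg.norm(ket)

p = probs_from_dense_column(ket)
print(p.shape)  # (16,)
print(p.dtype)  # float64
print(np.isclose(p.sum(), 1.0))  # True
```

Since `abs` of a complex array already returns real values, the trailing `.real` in the patch is harmless but not needed. Also note that densifying the MPS materializes all 2**n amplitudes, so this route gives up the memory savings of the tensor-network representation as the qubit count grows.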
Typo: the docstring for the function should be """Returns the probability vector, calculated as the squared modulus of each amplitude in the state vector."""
Oh wow, thanks for sharing @Marioherreroglez !
Let us know if you run into any other issues in the future, or if you have any suggestions or feedback in general.