Recreating PennyLane demo gives inaccurate results? Plus MPS acceleration is not compatible!

Hi there,
So I am trying to recreate the first half of the following PennyLane demo: Quantum Circuit Born Machines | PennyLane Demos. However, I would like to use PyTorch instead of JAX, which the demo uses. Below is the code I have written to attempt this, which mostly follows the tutorial. I have also tried to use Apple silicon (MPS) acceleration but have run into an error when I run on that device:

    import numpy as np
    import pennylane as qml
    import torch
    import torch.optim as optim

    # Check if MPS device is available
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    print(f'Using device: {device}')

    class MMD:

        def __init__(self, scales, space):
            gammas = 1 / (2 * (scales**2))
            sq_dists = np.abs(space[:, None] - space[None, :]) ** 2
            self.K = sum(np.exp(-gamma * sq_dists) for gamma in gammas) / len(scales)
            self.K = torch.tensor(self.K, dtype=torch.float64).to(device)
            self.scales = scales

        def k_expval(self, px, py):
            return torch.matmul(px, torch.matmul(self.K, py))

        def __call__(self, px, py):
            pxy = px - py
            return self.k_expval(pxy, pxy)
            
    class QCBM:

        def __init__(self, circ, mmd, py):
            self.circ = circ
            self.mmd = mmd
            self.py = torch.tensor(py, dtype=torch.float64).to(device)  # target distribution π(x)

        def mmd_loss(self, params):
            px = self.circ(params)
            return self.mmd(px, self.py), px
            
        def kl_divergence(self, px):
            # Avoid division by zero and handle log(0) cases
            qcbm_probs = px.clone().detach()
            target_probs = self.py
            kl_div = -torch.sum(target_probs * torch.nan_to_num(torch.log(qcbm_probs / target_probs)))
            return kl_div

    def get_bars_and_stripes(n):
        bitstrings = [list(np.binary_repr(i, n))[::-1] for i in range(2**n)]
        bitstrings = np.array(bitstrings, dtype=int)

        stripes = bitstrings.copy()
        stripes = np.repeat(stripes, n, 0)
        stripes = stripes.reshape(2**n, n * n)

        bars = bitstrings.copy()
        bars = bars.reshape(2**n * n, 1)
        bars = np.repeat(bars, n, 1)
        bars = bars.reshape(2**n, n * n)
        return np.vstack((stripes[0 : stripes.shape[0] - 1], bars[1 : bars.shape[0]]))

    n = 3
    n_qubits = n**2
    dev = qml.device("default.qubit", wires=n_qubits)
    n_layers = 6
    wshape = qml.StronglyEntanglingLayers.shape(n_layers=n_layers, n_wires=n_qubits)
    weights = np.random.random(size=wshape)
    weights = torch.tensor(weights, requires_grad=True, dtype=torch.float64).to(device)

    @qml.qnode(dev, interface='torch')
    def circuit(weights):
        qml.StronglyEntanglingLayers(
            weights=weights, ranges=[1] * n_layers, wires=range(n_qubits)
        )
        return qml.probs()

    data = get_bars_and_stripes(n)
    bitstrings = []
    nums = []
    for d in data:
        bitstrings += ["".join(str(int(i)) for i in d)]
        nums += [int(bitstrings[-1], 2)]
    probs = np.zeros(2**n_qubits)
    probs[nums] = 1 / len(data)
    probs = torch.tensor(probs, dtype=torch.float64).to(device)  # Ensure probs is a Float tensor

    bandwidth = np.array([0.25, 0.5, 1])
    space = np.arange(2**n_qubits)

    mmd = MMD(bandwidth, space)
    qcbm = QCBM(circuit, mmd, probs)

    optimizer = optim.Adam([weights], lr=0.1)

    # Training loop
    num_epochs = 100
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss, px = qcbm.mmd_loss(weights)
        loss.backward()
        optimizer.step()
        kl_div = qcbm.kl_divergence(px)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, KL Divergence: {kl_div.item()}')

This code runs smoothly on CPU; however, I notice that after 100 epochs the KL divergence and loss values are very different from the demo. I understand there will be some difference due to the seed, for example, but I would not expect an order-of-magnitude difference in results simply from using another framework. If anyone could shed light on why there is a discrepancy, that would be amazing.
For reference, in 100 epochs I get the following result with the above: Epoch 100/100, Loss: 0.009134388315731927, KL Divergence: 0.38180681247905385 whereas the demo achieves: Step: 90 Loss: 0.0004 KL-div: 0.0755.

Finally, when trying to run on the ‘mps’ device I get the following error:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

This is because qml.probs() returns a torch.float64 object by default and this cannot be changed on the user’s end. Are there any options in PennyLane to make the default output a torch.float32?
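Separately from what the QNode returns, the code above also casts its own tensors to float64 before calling .to(device), and PyTorch raises this same TypeError for any float64 tensor moved to MPS. A minimal sketch of choosing the dtype from the device follows. The names pick_device_and_dtype and probs_np are hypothetical and not part of PennyLane or the post, and this only removes the explicit casts in user code; it does not establish that the QNode itself works on MPS.

```python
import numpy as np
import torch


def pick_device_and_dtype():
    """Use float32 on MPS (which has no float64 support), float64 on CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps"), torch.float32
    return torch.device("cpu"), torch.float64


device, dtype = pick_device_and_dtype()

# Hypothetical stand-in for the target distribution built with NumPy (float64).
probs_np = np.full(8, 1 / 8)

# Passing dtype and device at construction avoids first creating a float64
# tensor and then moving it to a device that cannot hold it.
probs = torch.tensor(probs_np, dtype=dtype, device=device)
print(probs.dtype, probs.device)
```

The same pattern would apply to the kernel matrix self.K and to the initial weights in the code above.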

Hi @Aaron_Thomas,

It is possible that the starting point is unstable, meaning that slight differences like using a different seed or device can cause a big difference in the loss. You can test this by trying different seeds or calculating the gradient around that point. If it’s zero or close to zero, this could explain the behaviour.
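One way to run that check is to measure the gradient norm at the initial parameters for a few seeds. The sketch below is only an illustration: grad_norm is a hypothetical helper, and toy_loss is a stand-in so the snippet runs without the circuit.

```python
import torch


def grad_norm(loss_fn, params):
    """Return the L2 norm of d(loss)/d(params) evaluated at params."""
    # Work on a detached copy so the caller's tensor is left untouched.
    p = params.clone().detach().requires_grad_(True)
    loss = loss_fn(p)
    (grad,) = torch.autograd.grad(loss, p)
    return grad.norm().item()


def toy_loss(w):
    # Stand-in for the MMD loss; replace with the real loss to test the circuit.
    return (w ** 2).sum()


for seed in range(3):
    torch.manual_seed(seed)
    w0 = torch.rand(4, dtype=torch.float64)
    print(f"seed={seed}  |grad|={grad_norm(toy_loss, w0):.4f}")
```

With the code from the original post, the analogous call would be grad_norm(lambda w: qcbm.mmd_loss(w)[0], weights), since mmd_loss returns the loss together with the probabilities. A norm near zero at the starting point would be consistent with slow or stalled training.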

About MPS, I’m not sure if it’s possible to change the output to float32. Let me check with the team and get back to you on this.

Hiya,
I’ve now tested the code across multiple seeds and for more iterations than the demo and still cannot achieve the demo results. It seems very odd that switching to the torch interface creates an order-of-magnitude difference regardless of seed, unless there is something intrinsically different about my implementation?

Also, any update on the float64 issue? It means no Mac user can currently use GPU acceleration with the torch interface (because the Apple MPS device clearly has its own internal issues). If there was an option for the user to set the output precision to torch.float32, for example, that would be great.

Hi @Aaron_Thomas ,

I tested your code on Linux (Colab) and I got Epoch 100/100, Loss: 0.000976172725706519, KL Divergence: 0.08454529024009044

On Mac with CPU, with your code, I get Epoch 100/100, Loss: 0.0016277304392055332, KL Divergence: 0.14576009223326064

However I also got this warning:

/var/folders/wm/s9vs09cx1pbdp1hdgp93gq900000gp/T/ipykernel_73308/1003023866.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.py = torch.tensor(py, dtype=torch.float64).to(device)  # target distribution π(x)

So I changed self.py = torch.tensor(py, dtype=torch.float64).to(device) to self.py = py.clone().detach() and now I get the expected results.
Epoch 100/100, Loss: 0.00044723039783525173, KL Divergence: 0.07477580802773291

Let me know if you also get similar results with the same change.
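For anyone comparing the two forms, here is a small sketch of what the warning is about. The tensor probs is a hypothetical stand-in for the target distribution, which in the code above is already a torch tensor by the time it reaches QCBM.__init__.

```python
import warnings

import torch

# Stand-in for the target distribution, already a float64 tensor.
probs = torch.full((8,), 1 / 8, dtype=torch.float64)

# Re-wrapping an existing tensor with torch.tensor() is what triggers the
# UserWarning quoted above; record it instead of printing it.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    rewrapped = torch.tensor(probs, dtype=torch.float64)
print("warning raised:", any(issubclass(w.category, UserWarning) for w in caught))

# Recommended form: copy the data and detach it from any autograd history.
py = probs.clone().detach()
print(py.dtype, py.requires_grad)
```

Both forms produce a copy with the same values; the second simply follows the form PyTorch recommends in the warning.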

About the MPS issue I don’t have an answer yet. I hope to have one next week.

Hi @Aaron_Thomas,

I’m following up on the MPS issue. TLDR: this device is not supported at the moment.

In PennyLane, we try to have the output dtype (e.g., of qml.probs) be the same as the input dtype. If you run the above code with device = torch.device("cpu"), then you can alternate the weights between dtype=torch.float32 and dtype=torch.float64 and you’ll see that the output of the QNode follows suit.

When you try to use the MPS device, however, the logic breaks down, so it’s unclear whether the issue is coming from PennyLane or Torch itself.

Given the complexity of the issue and the fact that it may be coming from Torch itself, it’s better to consider the MPS device as unsupported. You can, if you want, open a GitHub issue detailing this bug/feature request.

I hope you can still get the results you want with the CPU device.

And thanks again for reporting this behaviour.

Hi @Aaron_Thomas ,

One of my colleagues had an idea. It’s not guaranteed to work but you can still try it:

For implicit conversions, I’d suggest trying lightning in 32-bit mode and seeing if the issue persists:

dev = qml.device("lightning.qubit", wires=x, c_dtype=np.complex64)

Let us know if the above succeeds or not. If it doesn’t work, then unfortunately there’s not much we can do.

Hi there,
Apologies for taking so long to reply. I really appreciate you having a look at this issue and presenting your solutions. I’m unsure why I was getting those results for the bars and stripes dataset on my system. I will have another go at this and try to diagnose what is happening on my local system.

Regarding the Mac ‘mps’ device issue, I will give what you have suggested a go! Hopefully this solves the problem. If not, I will report it here to let others know and open a GitHub issue so it can be tracked effectively.

Many thanks once again

No problem @Aaron_Thomas !

Thank you for your reply. Keep us updated on your results!