Error with Parameter-Shift Gradient Computation in PennyLane when Integrating with PyTorch & PyTorch Lightning

Hello!

I’m encountering an issue when integrating PyTorch Lightning with PennyLane in my hybrid quantum neural network model.

NotImplementedError: Computing the gradient of broadcasted tapes with respect to the broadcasted parameters using the parameter-shift rule gradient transform is currently not supported. See #4462 for details.

Complete Code Example:

Below is a self-contained code example that reproduces the behavior I’m observing. I’ve included comments to explain each part of the code. Please note that this is a part of my project that I’ve adapted for this example, with minimal implementations so the code runs independently.

# File: VQC.py
import math
import torch
import torch.nn as nn
import pennylane as qml
import numpy as np


def create_qnode(n_qubits, circuit, weight_shapes):
    dev = qml.device("lightning.qubit", wires=n_qubits)
    qnode_obj = qml.QNode(circuit, dev, interface="torch", diff_method="parameter-shift")
    return qml.qnn.TorchLayer(qnode_obj, weight_shapes)

def strongly_entangling_circuit(n_qubits, n_qdepth, rotation="X"):
    weight_shapes = {"weights": (n_qdepth, n_qubits, 3)}
    
    def circuit(inputs, weights):
        qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation=rotation)
        qml.StronglyEntanglingLayers(weights, wires=range(n_qubits), ranges=np.ones(n_qdepth, dtype=int))  # ranges must have one entry per layer
        return [qml.expval(qml.PauliY(wires=i)) for i in range(n_qubits)]
    
    return circuit, weight_shapes

class VQC(nn.Module):
    def __init__(self, size_in, n_qubits, n_qdepth, circuit_type="strongly_entangling"):
        super().__init__()
        self.size_in = size_in
        self.n_qubits = n_qubits
        self.n_qdepth = n_qdepth
        
        # Choose the appropriate circuit type
        if circuit_type == "strongly_entangling":
            self.circuit, self.weight_shapes = strongly_entangling_circuit(n_qubits, n_qdepth)
        else:
            raise ValueError("Invalid circuit type provided.")
        
        self.layers = nn.ModuleList(
            [create_qnode(self.n_qubits, self.circuit, self.weight_shapes) for _ in range(math.ceil(size_in / n_qubits))]
        )

    def forward(self, x):
        # Split the input into chunks of n_qubits features, one chunk per quantum layer
        x_split = torch.split(x, self.n_qubits, dim=1)
        outputs = [layer(inputs) for layer, inputs in zip(self.layers, x_split)]
        # Concatenate the outputs of individual quantum layers
        return torch.cat(outputs, dim=1)

# File: BaseModel.py
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss, MSELoss

def calculate_accuracy(y_hat, y):
    _, predicted = torch.max(y_hat, dim=1)
    correct = (predicted == y).float()
    accuracy = correct.sum() / len(correct)
    return accuracy

SUPPORTED_OPTIMIZERS = {
    "adam": Adam,
    "sgd": SGD
}

SUPPORTED_LOSS_FUNCTIONS = {
    "cross_entropy": CrossEntropyLoss,
    "mse": MSELoss
}

class BaseModel(LightningModule):
    def __init__(
            self,
            model_name: str,
            batch_size: int = 32,
            learning_rate: float = 0.001,
            max_epochs: int = 10,
            optimizer: str = "adam",
            loss_function: str = "cross_entropy",
            seed: int = 42,
            gpus: int | None = None  # None for CPU training
    ):
        super().__init__()
        self.model_name = model_name
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

        if optimizer.lower() not in SUPPORTED_OPTIMIZERS:
            raise ValueError(f"Unsupported optimizer: {optimizer}")
        self.optimizer_class = SUPPORTED_OPTIMIZERS[optimizer.lower()]

        if loss_function.lower() not in SUPPORTED_LOSS_FUNCTIONS:
            raise ValueError(f"Unsupported loss function: {loss_function}")
        self.loss_function = SUPPORTED_LOSS_FUNCTIONS[loss_function.lower()]()

        self.seed = seed
        self.gpus = gpus

        self._configure_trainer()
        
    def _configure_trainer(self):
        logger = TensorBoardLogger(f"tensorboard", name=self.model_name)
        self.trainer = Trainer(
            max_epochs=self.max_epochs,
            logger=logger,
            default_root_dir="tensorboard"
        )

    def configure_optimizers(self):
        return self.optimizer_class(self.parameters(), lr=self.learning_rate)

    def forward(self, x):
        raise NotImplementedError("Forward method must be implemented")

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        acc = calculate_accuracy(y_hat, y)

        self.log_dict({
            "train_loss": loss, 
            "train_acc": acc
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        acc = calculate_accuracy(y_hat, y)

        self.log_dict({
            "val_loss": loss, 
            "val_acc": acc
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True)
        
        return loss

    def fit(self, train_dataloaders, val_dataloaders):
        self.trainer.fit(self, train_dataloaders, val_dataloaders)

# File: HQNN_Parallel.py
from torch import flatten
from torch.nn import Conv2d, Linear, MaxPool2d, ReLU, BatchNorm2d, LogSoftmax

from BaseModel import BaseModel
from VQC import VQC

class HQNN_Parallel(BaseModel):
    def __init__(self, in_channels, classes, **kwargs):
        super().__init__(model_name='HQNN_Parallel', **kwargs)

        # Initialize first set of CONV => RELU => POOL layers
        self.conv1 = Conv2d(in_channels=in_channels, out_channels=16, kernel_size=(5,5), padding=2)
        self.batchNorm1 = BatchNorm2d(16)
        self.relu1 = ReLU()
        self.maxpool1 = MaxPool2d(kernel_size=(2,2), stride=(2,2))

        # Initialize second set of CONV => RELU => POOL layers
        self.conv2 = Conv2d(in_channels=16, out_channels=32, kernel_size=(5,5), padding=2)
        self.batchNorm2 = BatchNorm2d(32)
        self.relu2 = ReLU()
        self.maxpool2 = MaxPool2d(kernel_size=(2,2), stride=(2,2))

        # Initialize first (and only) set of FC => RELU layers
        self.fc1 = Linear(in_features=1568, out_features=20) 
        self.relu3 = ReLU()

        # Initialize the quantum layer
        self.qlayer1 = VQC(
            size_in=20,
            n_qubits=5,
            n_qdepth=3,
        )

        # Initialize our softmax classifier
        self.fc2 = Linear(in_features=20, out_features=classes)
        self.logSoftmax = LogSoftmax(dim=1)

    def forward(self, x):
        # ==== Convolutional layers
        x = self.conv1(x)
        x = self.batchNorm1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.batchNorm2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = flatten(x, 1)  # flatten the output from the previous layer

        # ==== Fully Connected layer
        x = self.fc1(x)
        x = self.relu3(x)

        # ==== Quantum layer
        x = self.qlayer1(x)

        # ==== Softmax classifier
        x = self.fc2(x)
        output = self.logSoftmax(x)

        return output  # return the output predictions

# File: Main.py
from torch.utils.data import DataLoader, TensorDataset
import torch

from HQNN_Parallel import HQNN_Parallel

if __name__ == "__main__":
    # Create dummy datasets
    X_train = torch.randn(1500, 1, 28, 28)
    y_train = torch.randint(0, 10, (1500,))
    X_val = torch.randn(300, 1, 28, 28)
    y_val = torch.randint(0, 10, (300,))

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=64)
    val_loader = DataLoader(val_dataset, batch_size=64)

    # Initialize and train the model
    model = HQNN_Parallel(in_channels=1, classes=10)
    model.fit(train_loader, val_loader)

Full Error Message:

Traceback (most recent call last):
File "C:\Users\___\src\Main.py", line 74, in <module>
  model1.fit(train_dataset, val_dataset)
File "C:\Users\___\src\models\Hybrid\HQNN_Parallel.py", line 123, in fit
  loss.backward()
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\torch\_tensor.py", line 581, in backward
  torch.autograd.backward(
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\torch\autograd\__init__.py", line 347, in backward    
  _engine_run_backward(
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\torch\autograd\graph.py", line 825, in _engine_run_backward
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass       
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\torch\autograd\function.py", line 307, in apply       
  return user_fn(self, *args)
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\workflow\interfaces\torch.py", line 101, in 
new_backward
  grad_inputs = orig_bw(ctx, *grad_outputs)
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\workflow\interfaces\torch.py", line 186, in 
backward
  vjps = ctx.jpc.compute_vjp(ctx.tapes, dy)
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\workflow\jacobian_products.py", line 300, in compute_vjp
  vjp_tapes, processing_fn = qml.gradients.batch_vjp(
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\gradients\vjp.py", line 502, in batch_vjp   
  g_tapes, fn = vjp(tape, dy, gradient_fn, gradient_kwargs=gradient_kwargs)
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\gradients\vjp.py", line 363, in vjp
  gradient_tapes, fn = gradient_fn(tape, **gradient_kwargs)
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\transforms\core\transform_dispatcher.py", line 100, in __call__
  intermediate_tapes, post_processing_fn = self._transform(
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\gradients\parameter_shift.py", line 1114, in param_shift
  assert_no_trainable_tape_batching(tape, transform_name)
File "C:\Users\___\.conda\envs\QAML\lib\site-packages\pennylane\gradients\gradient_transform.py", line 97, in assert_no_trainable_tape_batching

raise NotImplementedError(
  NotImplementedError: Computing the gradient of broadcasted tapes with respect to the broadcasted parameters using the parameter-shift rule gradient transform is currently not supported. See #4462 for details.

Pennylane information:

Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: c:\users\__\.conda\envs\qaml\lib\site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Windows-10-10.0.19045-SP0
Python version:          3.10.15
Numpy version:           1.26.4
Scipy version:           1.14.1
Installed devices:
- default.clifford (PennyLane-0.39.0)
- default.gaussian (PennyLane-0.39.0)
- default.mixed (PennyLane-0.39.0)
- default.qubit (PennyLane-0.39.0)
- default.qutrit (PennyLane-0.39.0)
- default.qutrit.mixed (PennyLane-0.39.0)
- default.tensor (PennyLane-0.39.0)
- null.qubit (PennyLane-0.39.0)
- reference.qubit (PennyLane-0.39.0)
- lightning.qubit (PennyLane_Lightning-0.39.0)

Installed Packages:

absl-py                        2.1.0
adversarial-robustness-toolbox 1.18.2
aiohappyeyeballs               2.4.3
aiohttp                        3.11.7
aiosignal                      1.3.1
appdirs                        1.4.4
async-timeout                  5.0.1
attrs                          24.2.0
autograd                       1.7.0
autoray                        0.7.0
cachetools                     5.5.0
certifi                        2024.8.30
charset-normalizer             3.4.0
clarabel                       0.9.0
colorama                       0.4.6
contourpy                      1.3.1
cvxpy                          1.6.0
cycler                         0.12.1
ffmpeg-python                  0.2.0
filelock                       3.16.1
fonttools                      4.55.0
frozenlist                     1.5.0
fsspec                         2024.10.0
future                         1.0.0
grpcio                         1.68.0
idna                           3.10
Jinja2                         3.1.4
joblib                         1.4.2
kiwisolver                     1.4.7
kornia                         0.7.4
kornia_rs                      0.1.7
lightning                      2.4.0
lightning-utilities            0.11.9
Markdown                       3.7
MarkupSafe                     3.0.2
matplotlib                     3.9.2
mpmath                         1.3.0
multidict                      6.1.0
networkx                       3.4.2
numpy                          1.26.4
opencv-python                  4.10.0.84
osqp                           0.6.7.post3
packaging                      24.2
PennyLane                      0.39.0
PennyLane_Lightning            0.39.0
pillow                         11.0.0
pip                            24.3.1
propcache                      0.2.0
protobuf                       5.28.3
pyparsing                      3.2.0
python-dateutil                2.9.0.post0
pytorch-lightning              2.4.0
PyYAML                         6.0.2
qdldl                          0.1.7.post4
requests                       2.32.3
rustworkx                      0.15.1
scikit-learn                   1.5.2
scipy                          1.14.1
scs                            3.2.7
setuptools                     75.6.0
six                            1.16.0
sympy                          1.13.1
tensorboard                    2.18.0
tensorboard-data-server        0.7.2
threadpoolctl                  3.5.0
toml                           0.10.2
torch                          2.5.1
torchmetrics                   1.6.0
torchvision                    0.20.1
tqdm                           4.67.1
typing_extensions              4.12.2
urllib3                        2.2.3
Werkzeug                       3.1.3
wheel                          0.45.1
yarl                           1.18.0

What I’ve Tried:

  • Uninstalled and reinstalled the relevant packages
  • Downgraded PennyLane to version 0.38
  • Changed the device from lightning.qubit to default.qubit
  • Simplified my model to avoid using PyTorch Lightning, which I initially thought was the problem
  • Multiple other smaller tweaks

Thank you for your assistance!

Hi @Lazarus,

I’m sorry to hear you’ve been struggling and trying out a bunch of things.

I think the issue here is with TorchLayer.
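
For context, here's a minimal sketch of what I believe is going on (assuming my reading of the traceback is right): when a batched input that requires gradients reaches a QNode, PennyLane records it as a single broadcasted tape, and the parameter-shift transform cannot differentiate with respect to broadcasted parameters. TorchLayer passes the whole batch to the QNode at once, which is exactly this situation:

import torch
import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# A batched, trainable input becomes one broadcasted tape parameter
x = torch.randn(4, requires_grad=True)
circuit(x).sum().backward()  # should raise the same NotImplementedError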

A few months ago someone had a similar problem, so I created this example where I define a class that inherits from nn.Module instead of using TorchLayer. I’d suggest trying something similar, where you apply the quantum circuit to each individual element of the batch.
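
As a rough illustration only (untested, with a placeholder name QuantumLayer, adapted to the 5-qubit StronglyEntanglingLayers setup in your VQC), the idea looks something like this:

import torch
import torch.nn as nn
import pennylane as qml

n_qubits, n_qdepth = 5, 3
dev = qml.device("lightning.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def circuit(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="X")
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliY(wires=i)) for i in range(n_qubits)]

class QuantumLayer(nn.Module):  # hypothetical replacement for TorchLayer
    def __init__(self):
        super().__init__()
        shape = qml.StronglyEntanglingLayers.shape(n_layers=n_qdepth, n_wires=n_qubits)
        self.weights = nn.Parameter(0.1 * torch.randn(shape))

    def forward(self, x):
        # Run the circuit one sample at a time so no tape is ever broadcasted
        results = [torch.stack(circuit(sample, self.weights)) for sample in x]
        return torch.stack(results)

This is slower, since every sample becomes its own circuit execution, but it sidesteps the broadcasted-tape limitation entirely.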

Let me know if this works for you!
