I am trying to train two models on a GPU cluster: (1) a hybrid quantum-classical reinforcement learning (RL) model and (2) a classical deep RL model. The models are fairly complex, so it is difficult to share the complete code on this forum. When I train the classical model, training runs on the GPU. When I train the quantum-classical model, however, GPU utilization stays at 0, which makes each iteration extremely slow. While debugging, I confirmed that the input tensor and the model parameters (both classical and quantum) are on the GPU. Further details of my implementation follow. I am using a V100-16GB GPU for the
dev = qml.device("default.qubit", wires=4)
# The qnode is created using the following annotation:
# @qml.qnode(dev, interface='torch')
Please let me know if there are any flags I should change, or any tensors I should inspect, to figure out why the GPU is not being used for the quantum simulation. Apologies that I couldn’t share a working code snippet to reproduce the issue.
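For reference, the placement check I ran on the tensors amounts to something like the sketch below. It is a simplified stand-in, not my actual code: `device_summary` is a hypothetical helper name, and the `SimpleNamespace` objects only mimic the `.device` attribute of real torch tensors so the snippet runs without a GPU.

```python
from collections import Counter
from types import SimpleNamespace


def device_summary(named_tensors):
    """Count tensors per device for an iterable of (name, tensor) pairs.

    Works with anything exposing a `.device` attribute, for example the
    pairs returned by `model.named_parameters()` in PyTorch.
    """
    return Counter(str(tensor.device) for _, tensor in named_tensors)


# Hypothetical stand-ins for real parameters (names and devices are made up).
fake_params = [
    ("classical.weight", SimpleNamespace(device="cuda:0")),
    ("quantum.weights", SimpleNamespace(device="cpu")),
]
summary = device_summary(fake_params)
print(summary)
```

With a real model, the call would be `device_summary(model.named_parameters())`; any entry other than the expected CUDA device would point to a tensor that was left on the CPU.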
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
License: Apache License 2.0
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning, PennyLane-Lightning-GPU
Platform info: Linux-5.15.0-1028-nvidia-x86_64-with-glibc2.17
Python version: 3.8.16
Numpy version: 1.22.3
Scipy version: 1.7.3
Installed devices:
- lightning.gpu (PennyLane-Lightning-GPU-0.30.0)
- lightning.qubit (PennyLane-Lightning-0.28.1)
- default.gaussian (PennyLane-0.28.0)
- default.mixed (PennyLane-0.28.0)
- default.qubit (PennyLane-0.28.0)
- default.qubit.autograd (PennyLane-0.28.0)
- default.qubit.jax (PennyLane-0.28.0)
- default.qubit.tf (PennyLane-0.28.0)
- default.qubit.torch (PennyLane-0.28.0)
- default.qutrit (PennyLane-0.28.0)
- null.qubit (PennyLane-0.28.0)
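The device list above shows plugins at different versions (0.28.0, 0.28.1 and 0.30.0). I do not know whether this matters for my issue, but a small sketch like the following can group the listed versions by package so any mismatch is easy to see. The helper name `versions_by_package` and the regular expression are my own assumptions about the line format, based only on the lines shown above.

```python
import re
from collections import defaultdict

# Matches lines such as "- lightning.gpu (PennyLane-Lightning-GPU-0.30.0)".
DEVICE_LINE = re.compile(r"-\s+(\S+)\s+\((.+)-(\d+(?:\.\d+)*)\)")


def versions_by_package(listing):
    """Map each package name in the device listing to the set of versions seen."""
    found = defaultdict(set)
    for line in listing.splitlines():
        match = DEVICE_LINE.search(line)
        if match:
            _device, package, version = match.groups()
            found[package].add(version)
    return dict(found)


# A few lines copied from the device list above.
listing = """
- lightning.gpu (PennyLane-Lightning-GPU-0.30.0)
- lightning.qubit (PennyLane-Lightning-0.28.1)
- default.qubit (PennyLane-0.28.0)
- default.qubit.torch (PennyLane-0.28.0)
"""
packages = versions_by_package(listing)
for package, versions in sorted(packages.items()):
    print(package, sorted(versions))
```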