Issue with parallelization/broadcasting in PyTorch

Dear Pennylane community,

I hope that you are doing well.

I was trying to run a lightning.qubit circuit with the PyTorch interface. In my model's forward pass, I run the following command:

x = torch.vmap(self.q_circ_torch)(self.scaling * x).to(DEVICE)

where self.q_circ_torch is a qml.qnn.TorchLayer. This line raises the following error:

Traceback (most recent call last):
File "/usr/local/lib/python3.11/dist-packages/joblib/externals/loky/process_executor.py", line 490, in _process_worker
r = call_item()
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/joblib/externals/loky/process_executor.py", line 291, in __call__
return self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/joblib/parallel.py", line 607, in __call__
return [func(*args, **kwargs) for func, args, kwargs in self.items]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/joblib/parallel.py", line 607, in
return [func(*args, **kwargs) for func, args, kwargs in self.items]
^^^^^^^^^^^^^^^^^^^^^
File "/opt/qvc/main.py", line 20, in optimize_analyze_seed
study.optimize(
File "/usr/local/lib/python3.11/dist-packages/optuna/study/study.py", line 490, in optimize
_optimize(
File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 67, in _optimize
_optimize_sequential(
File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 164, in _optimize_sequential
frozen_trial_id = _run_trial(study, func, catch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 262, in _run_trial
raise func_err
File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 205, in _run_trial
value_or_values = func(trial)
^^^^^^^^^^^
File "/opt/qvc/main.py", line 21, in
lambda trial: objective(trial, f"study={STUDY_NAME}"),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/qvc/hpo_dataset_load.py", line 91, in objective
list_val_acc, lr_schedule = train_model(model, train_loader, val_loader,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/qvc/train_eval.py", line 129, in train_model
outputs = model(inputs)
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/qvc/qcs_qre.py", line 156, in forward
x = torch.vmap(self.q_circ_torch)(self.scaling * x).to(DEVICE)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/apis.py", line 208, in wrapped
return vmap_impl(
^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
return _flat_vmap(
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/qnn/torch.py", line 407, in forward
results = self._evaluate_qnode(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/qnn/torch.py", line 433, in _evaluate_qnode
res = self.qnode(**kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py", line 863, in __call__
return self._impl_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py", line 836, in _impl_call
res = execute(
^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/workflow/execution.py", line 238, in execute
results = run(tapes, device, config, inner_transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/workflow/run.py", line 345, in run
results = ml_execute(tapes, execute_fn, jpc, device=device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/workflow/interfaces/torch.py", line 238, in execute
return ExecuteTapes.apply(kwargs, *parameters)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/pennylane/workflow/interfaces/torch.py", line 89, in new_apply
flat_out = orig_apply(out_struct_holder, *inp)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 584, in apply
raise RuntimeError(
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, …), it must override the setup_context staticmethod. For more details, please see Extending torch.func with autograd.Function — PyTorch main documentation
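For context, the last frames show the error being raised when PennyLane's torch interface calls ExecuteTapes.apply (a torch.autograd.Function) inside torch.vmap. The message describes what PyTorch 2.x requires of any custom autograd.Function used under functorch transforms: a forward that does not take ctx, plus a separate setup_context staticmethod. The sketch below illustrates that requirement with a hypothetical toy Cube function; it is not PennyLane's code. It uses generate_vmap_rule = True, which is one of the two vmap options PyTorch documents (the other is a custom vmap staticmethod).

```python
import torch
from torch.func import grad, vmap


class OldStyleCube(torch.autograd.Function):
    """Hypothetical toy op y = x**3 in the classic style (ctx in forward)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2


class Cube(torch.autograd.Function):
    """Same toy op, written in the functorch-compatible style."""

    # Let vmap derive the batching rule by running forward/backward under
    # vmap; this works here because both only use plain PyTorch ops.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # No ctx argument: forward only computes the output.
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Save what backward needs, separately from forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2


batch = torch.arange(6.0).reshape(3, 2)

# The classic style raises the same kind of RuntimeError as in the traceback.
try:
    vmap(OldStyleCube.apply)(batch)
    old_style_error = None
except RuntimeError as err:
    old_style_error = str(err)

# The setup_context style works under vmap, including per-sample gradients.
values = vmap(Cube.apply)(batch)
per_sample_grads = vmap(grad(lambda x: Cube.apply(x).sum()))(batch)
```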

I verified that the circuit works with default.qubit, so the issue very likely comes from lightning.qubit. Are there any recommendations for replacing torch.vmap with another operation that has the same effect? Alternatively, can a similar parallelization operation be used with another interface?
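As a reference point, torch.vmap(f)(xs) behaves like running f on each slice of xs along dimension 0 and stacking the results. The minimal sketch below shows that explicit loop as a sequential (and therefore slower) fallback. The helpers loop_vmap and per_sample and the scaling value are hypothetical stand-ins, with per_sample playing the role of the quantum layer; the equivalence assumes the function maps one sample to one output tensor.

```python
import torch


def loop_vmap(fn, xs):
    """Sequential stand-in for torch.vmap(fn)(xs) over dimension 0."""
    # Apply fn to each sample separately, then stack the outputs
    # along a new leading batch dimension.
    return torch.stack([fn(x) for x in xs.unbind(0)], dim=0)


def per_sample(x):
    # Hypothetical per-sample function standing in for the quantum layer.
    return torch.cos(x).sum(dim=-1, keepdim=True)


scaling = 0.5  # hypothetical stand-in for self.scaling
xs = torch.randn(4, 3, requires_grad=True)

looped = loop_vmap(per_sample, scaling * xs)
vmapped = torch.vmap(per_sample)(scaling * xs)

# Gradients flow through the loop version as usual.
looped.sum().backward()
```

The loop gives up vectorization, so it only helps when the callable cannot run under torch.vmap at all.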

I have seen the following related post: "Incompatible function arguments error on lightning.qubit with JAX" (PennyLane Discussion Forum), which indicates that broadcasting with jax.vmap was not implemented for lightning.qubit with JAX, since it did not show an advantage. Is the same true for PyTorch?

Thank you in advance for your help!

Taha

Hi @Taha ,

It’s hard to tell what the problem is without seeing a minimal reproducible example. If you could share one, that would be the best way for us to see what’s happening.

From the error traceback, it looks like you might be trying to use autograd data (e.g., NumPy data) within the Torch interface. Since only one interface can be used at a time, this could be the cause of the issue.

That being said, moving to JAX would be a good idea. We’ve added much more support for JAX within the PennyLane ecosystem over the past few years, so if you can easily switch to JAX, I would recommend doing so.

If you try any of these potential solutions, could you please let us know here whether they work?

If they don’t work, please share a minimal reproducible example so that we can investigate the issue further.

I hope this helps!