Does qml.kernels.kernel_matrix computation run in parallel in Braket?

TL;DR: does qml.kernels.kernel_matrix leverage parallel execution on AWS Braket, like the gradient-based optimisers do?

Hi, I’ve been using PennyLane for quantum machine learning research for a while now, and in most situations I have found JAX with JIT compilation to be far superior to any other interface I’ve tried, especially when I can batch calculations using jax.vmap. I can then batch over the input data only, and not the trainable parameters, which makes sense when calculating the loss over a batch of data.
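For readers unfamiliar with this pattern, here is a minimal sketch of batching over the input data only. It is not the poster's code: `model` is a hypothetical toy stand-in for a QNode (the closed-form expectation of a single-qubit circuit), and the data values are made up for illustration. The relevant part is `in_axes=(0, None)`, which maps over the inputs while sharing the trainable parameter.

```python
import jax
import jax.numpy as jnp

def model(x, theta):
    # Toy stand-in for a QNode (assumption): <Z> after RY(x) then RY(theta)
    # applied to |0> equals cos(x + theta).
    return jnp.cos(x + theta)

# Map over axis 0 of the inputs only; theta (in_axes=None) is shared by the batch.
batched_model = jax.vmap(model, in_axes=(0, None))

@jax.jit
def loss(theta, xs, ys):
    # Mean squared error over the whole batch of inputs.
    preds = batched_model(xs, theta)
    return jnp.mean((preds - ys) ** 2)

xs = jnp.array([0.0, 0.5, 1.0])
ys = jnp.cos(xs + 0.3)   # illustrative targets generated with theta = 0.3
theta = jnp.array(0.1)

# jax.grad differentiates w.r.t. the first argument (theta) by default.
grad_theta = jax.grad(loss)(theta, xs, ys)
```

Because theta is not mapped, the gradient comes back with the shape of theta itself rather than one gradient per data point.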

I’ve been working on quantum kernel estimation recently, and I’m investigating how best to scale my simulations up. I wrote my own batched kernel matrix calculation, as the built-in qml.kernels module does not leverage JAX optimally. If I’m not doing kernel alignment, I seem to be able to run pretty large batches using JAX. When doing kernel alignment (i.e. calculating gradients), it seems I have to use rather small batch sizes anyway, or the JIT compilation just stalls.
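To illustrate what a batched kernel matrix calculation can look like on a statevector simulator, here is a rough sketch. It is not the poster's implementation: it assumes a simple RY angle-embedding feature map (the hypothetical `embed` function) and is written in NumPy for brevity. The idea is to simulate one state per data point and then obtain all pairwise fidelities from a single matrix product, rather than evaluating one circuit per pair.

```python
import numpy as np

def embed(x):
    # Assumed feature map: RY(x_i) on each qubit, giving a product state.
    # RY(a)|0> = [cos(a/2), sin(a/2)].
    state = np.array([1.0])
    for angle in x:
        state = np.kron(state, np.array([np.cos(angle / 2), np.sin(angle / 2)]))
    return state

def kernel_matrix_batched(X):
    # One state per data point (N simulations) ...
    states = np.stack([embed(x) for x in X])   # shape (N, 2**n_qubits)
    # ... then all N^2 overlaps at once: overlaps[i, j] = <psi(x_j)|psi(x_i)>.
    overlaps = states @ states.conj().T
    # Fidelity kernel: squared magnitude of each overlap.
    return np.abs(overlaps) ** 2

X = np.array([[0.1, 0.2], [0.4, 0.3], [1.0, -0.5]])   # illustrative data
K = kernel_matrix_batched(X)
```

This only works when the full statevector is accessible, so it applies to simulation and not to shot-based hardware runs.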

Looking at scaling beyond 20 qubits, I’m trying out the AWS Braket SV1 simulator, and I see that the documentation/tutorials show that at least the built-in optimisation utilises parallel execution on Braket. I’m wondering whether the qml.kernels functions do that as well? I could maybe find out by digging into the code, but I just wanted to ask first.

Hi @einar_bui_magnusson, welcome to the forum!

You’re right that the built-in qml.kernels module doesn’t leverage JAX optimally.

The AWS Braket plugin only uses parallel execution for gradients and gradient-based optimizers. I couldn’t find anything related to parallel execution of kernel_matrix.

We will take this into account, and maybe we can make it available in a future release.

Thank you very much for asking this question!


Since calculating a kernel with quantum circuits is not complicated, you could implement it directly with the Braket SDK.
Batch execution on SV1 is then quite simple.
(Here is a sample from the Braket tutorial; it assumes bell, device and s3_folder are already defined.)

circuits = [bell for _ in range(5)]
batch = device.run_batch(circuits, s3_folder, shots=100)
print(batch.results()[0].measurement_counts)  # The result of the first task in the batch

Note: Since SV1 is a remote simulator, each execution, including the communication delay, takes on the order of seconds.
If you have a large number of input data points, the whole calculation takes a lot of time even with SV1.
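As a rough sketch of how the bookkeeping around such a batch could look: collect every pair of data points, submit all pair circuits in one batch, and fill a symmetric matrix from the results. Everything here is an assumption for illustration. `estimate_overlaps` is a hypothetical callable that, in a real setup, would build one circuit per pair, submit them together via device.run_batch, and read the all-zeros frequency from each result; `classical_stand_in` is only a placeholder so the sketch runs without AWS access.

```python
import numpy as np

def kernel_matrix_from_batch(X, estimate_overlaps):
    # estimate_overlaps (hypothetical): takes a list of (x_i, x_j) pairs and
    # returns one kernel value per pair, in the same order.
    n = len(X)
    # Only the upper triangle is needed, which roughly halves the batch size.
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    values = estimate_overlaps([(X[i], X[j]) for i, j in pairs])
    # Assumption: k(x, x) = 1, as for a noiseless fidelity kernel.
    K = np.eye(n)
    for (i, j), v in zip(pairs, values):
        K[i, j] = K[j, i] = v   # the kernel is symmetric
    return K

def classical_stand_in(pairs):
    # Placeholder for the remote batch call; this is NOT a quantum kernel.
    return [float(np.exp(-np.sum((a - b) ** 2))) for a, b in pairs]

X = np.array([[0.0, 0.1], [0.2, 0.3], [0.5, 0.4], [0.9, 0.8]])   # illustrative data
K = kernel_matrix_from_batch(X, classical_stand_in)
```

For N data points this submits N(N-1)/2 circuits in a single batch, so the per-task communication delay is not paid sequentially for every pair.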

Thanks for this suggestion @Kuma-quant!

Would you like to contribute this feature yourself? We can guide you through the process if you need it. If not, just let me know.