TL;DR: does qml.kernels.kernel_matrix leverage parallel execution on AWS Braket, like the gradient-based optimisers do?
Hi, I’ve been using PennyLane for quantum machine learning research for a while now, and in most situations I’ve found that JAX with JIT compilation is far superior to any other interface I’ve tried, especially when I can batch calculations using jax.vmap. I then batch over the input data only, not the trainable parameters, which makes sense when calculating the loss over a batch of data.
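To make the batching pattern concrete, here is a minimal sketch (with a toy stand-in function rather than an actual circuit) of what I mean by mapping over the data axis while broadcasting the trainable parameters:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a quantum model: fixed trainable params, one data point in.
def model(params, x):
    return jnp.sum(jnp.sin(params) * x)

# in_axes=(None, 0): params are broadcast, x is mapped over its leading axis,
# so we batch over input data only, not over the trainable parameters.
batched_model = jax.jit(jax.vmap(model, in_axes=(None, 0)))

params = jnp.array([0.1, 0.2, 0.3])
X = jnp.ones((5, 3))  # batch of 5 data points
out = batched_model(params, X)
print(out.shape)  # (5,)
```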
I’ve been working on quantum kernel estimation recently, and I’m investigating how best to scale my simulations up. I wrote my own batched kernel matrix calculation, as the built-in qml.kernels.kernel_matrix does not leverage JAX optimally. If I’m not doing kernel alignment, I can run pretty large batches with JAX. When doing kernel alignment (i.e. calculating gradients), however, I have to use rather small batch sizes, or the JIT compilation just stalls.
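For reference, my batched kernel matrix calculation follows roughly this shape; the kernel function here is a hypothetical classical stand-in for the quantum fidelity kernel circuit, just to show the nested-vmap structure:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a quantum kernel k(x1, x2) = |<phi(x1)|phi(x2)>|^2.
def kernel(x1, x2):
    return jnp.exp(-jnp.sum((x1 - x2) ** 2))

# Nested vmap builds the full Gram matrix in one batched, jitted call:
# the inner vmap maps one x1 against all x2, the outer maps over all x1.
kernel_row = jax.vmap(kernel, in_axes=(None, 0))
kernel_matrix = jax.jit(jax.vmap(kernel_row, in_axes=(0, None)))

X = jnp.linspace(0.0, 1.0, 12).reshape(4, 3)
K = kernel_matrix(X, X)
print(K.shape)  # (4, 4)
```

With an actual PennyLane QNode on the JAX interface in place of `kernel`, this is the kind of batching that the built-in kernel utilities don’t seem to exploit.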
To scale beyond 20 qubits, I’m trying out the AWS Braket SV1 simulator, and I see from the documentation/tutorials that at least the built-in optimisers utilise parallel execution on Braket. I’m wondering whether the qml.kernels functions do that as well? I could probably find out by digging into the code, but I wanted to ask first.
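For context, this is the device configuration I’m using, following the PennyLane–Braket plugin docs (the `parallel=True` flag is what enables batched circuit executions to be submitted as concurrent Braket tasks; `max_parallel` is optional):

```python
import pennylane as qml

# SV1 on-demand state-vector simulator ARN, as given in the Braket docs.
device_arn = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"

dev = qml.device(
    "braket.aws.qubit",
    device_arn=device_arn,
    wires=22,
    parallel=True,     # submit batched executions as parallel Braket tasks
    max_parallel=10,   # cap on concurrent tasks
)
```

My question is essentially whether the circuit executions generated by qml.kernels.kernel_matrix get batched so that this parallelism kicks in, or whether they run sequentially.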