Speeding up grad computation

Hi, continuing on this thread, here are some benchmarks of gradient computation comparing JIT vs. no JIT (in JAX). With backprop (analytic, `shots=None`), JIT makes gradient computation roughly 3 orders of magnitude faster. With a finite number of shots, it's quite a bit slower, with or without JIT:

JIT True | shots None | time taken 5.002456e-05
JIT False | shots None | time taken 2.144583e-02

JIT True | shots 500 | time taken 7.670196e-02
JIT False | shots 500 | time taken 7.490716e-02

Here is the code snippet. I hope it helps give a clear picture and is useful for future benchmarking:

import time

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
from jax import config

config.update("jax_enable_x64", True)


def timeit(func, params):
    """Time a single call of ``func(params)``, including device execution.

    JAX dispatches work asynchronously: a call can return before the device
    has finished computing. The result is therefore passed to
    ``jax.block_until_ready`` before the clock is stopped, so the measured
    time covers the actual computation rather than just the dispatch.

    Args:
        func (Callable): A function to call.
        params (array): The inputs to the function.

    Returns:
        float: Wall-clock seconds taken by the call.
    """
    tic = time.perf_counter()
    # Waits on every array leaf of the (possibly nested) result; per the JAX
    # docs, leaves without a ``block_until_ready`` method are left as-is.
    jax.block_until_ready(func(params))
    toc = time.perf_counter()
    return toc - tic


# Benchmark configuration, read by get_grad_fn and by the timing loop.
n_layers = 5  # number of layers in the ansatz
N = 5  # number of wires (qubits) on the device
variational_ansatz = qml.BasicEntanglerLayers  # template applied in the circuit


def get_grad_fn(shots=None, diff_method='best', jit=True):
    """Get the gradient function with a combination of shots, diff_method and JIT.

    Builds a ``default.qubit.jax`` device with ``N`` wires, wraps the
    module-level ``variational_ansatz`` in a QNode that measures the
    expectation value of PauliZ on wire 0, and returns ``jax.grad`` of it.

    Args:
        shots (int, optional): Number of shots. Defaults to None.
        diff_method (str, optional): The method to use for gradient computation.
            Defaults to 'best'.
        jit (bool, optional): Should the gradient function be JIT compiled.
            Defaults to True.

    Returns:
        Callable: A function mapping the ansatz parameters to the gradient of
        the circuit's expectation value with respect to those parameters.
    """
    dev = qml.device("default.qubit.jax", wires=N, shots=shots)

    # The QNode is deliberately NOT wrapped in ``jax.jit`` here: compilation is
    # controlled solely by the ``jit`` flag below. Jitting the circuit
    # unconditionally would make the ``jit=False`` variant differentiate a
    # compiled circuit, invalidating the JIT vs. no-JIT comparison.
    @qml.qnode(dev, interface="jax", diff_method=diff_method)
    def circuit(params):
        """Variational circuit that we want to optimize.

        Args:
            params (jax array): Weights for ``variational_ansatz``; presumably
                of the shape returned by
                ``variational_ansatz.shape(n_layers=..., n_wires=N)``.

        Returns:
            float: Expectation value of PauliZ on wire 0.
        """
        variational_ansatz(params, wires=range(N))
        return qml.expval(qml.PauliZ(0))

    grad_x = jax.grad(circuit)
    return jax.jit(grad_x) if jit else grad_x


# Benchmark every (shots, jit) combination and report the mean time per gradient call.
num_repeat = 100
param_shape = variational_ansatz.shape(n_layers=n_layers, n_wires=N)

for shots in [None, 500]:
    for jit in [True, False]:
        grad_x = get_grad_fn(shots=shots, jit=jit)

        key = jax.random.PRNGKey(42)
        key, subkey = jax.random.split(key)
        x = jax.random.uniform(subkey, param_shape)

        # Warm-up call: triggers tracing/compilation. Block on the result so
        # none of that work spills over into the first timed call.
        jax.block_until_ready(grad_x(x))

        times = np.empty(num_repeat)
        for i in range(num_repeat):
            # Split the key so each repetition actually sees fresh parameters.
            key, subkey = jax.random.split(key)
            x = jax.random.uniform(subkey, param_shape)
            times[i] = timeit(grad_x, x)

        print(f"JIT {jit} | shots {shots} | time taken ", "{:e}".format(np.mean(times)))
1 Like