Hi, continuing on this thread, I have some benchmarks for gradient computation comparing JIT vs. no JIT (in JAX). Using backprop, we get gradients up to ~3 orders of magnitude faster. Using a finite number of shots, with or without JIT, it's quite a bit slower:
JIT True | shots None | time taken 5.002456e-05
JIT False | shots None | time taken 2.144583e-02
JIT True | shots 500 | time taken 7.670196e-02
JIT False | shots 500 | time taken 7.490716e-02
Here is the code snippet. I hope it helps give a clear picture and is useful for future benchmarking:
import pennylane as qml
import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
import time
def timeit(func, params):
    """Time a single call of ``func(params)``.

    JAX dispatches work asynchronously, so a call may return before the
    computation has actually finished. The result is therefore passed
    through ``jax.block_until_ready`` before the clock is stopped, so the
    measurement covers the computation itself and not only its dispatch.
    Results that are not JAX arrays are left untouched.

    Args:
        func (Callable): A function to call.
        params (array): The inputs to the function.

    Returns:
        float: The wall-clock time in seconds taken to run the function.
    """
    tic = time.perf_counter()
    # Block inside the timed region; otherwise only the dispatch is timed.
    jax.block_until_ready(func(params))
    toc = time.perf_counter()
    return toc - tic
# Circuit template used as the variational ansatz.
variational_ansatz = qml.BasicEntanglerLayers
# Number of wires of the device/ansatz and number of ansatz layers.
N, n_layers = 5, 5
def get_grad_fn(shots=None, diff_method='best', jit=True):
    """Build the gradient function of the variational circuit.

    The circuit applies the module-level ``variational_ansatz`` on ``N``
    wires and returns the expectation value of ``PauliZ`` on wire 0.

    Args:
        shots (int, optional): Number of shots passed to the device.
            Defaults to None.
        diff_method (str, optional): The differentiation method passed to
            the QNode. Defaults to 'best'.
        jit (bool, optional): Whether the gradient function is JIT compiled.
            When False nothing is wrapped in ``jax.jit``, so the two settings
            give a genuine JIT vs. no-JIT comparison. Defaults to True.

    Returns:
        Callable: A function that takes the circuit parameters and returns
        the gradient of the expectation value with respect to them.
    """
    dev = qml.device("default.qubit.jax", wires=N, shots=shots)

    # The circuit itself is deliberately NOT wrapped in ``jax.jit``: an
    # unconditional inner jit would compile the circuit even for jit=False.
    # When jit=True, the outer ``jax.jit`` below compiles the whole gradient.
    @qml.qnode(dev, interface="jax", diff_method=diff_method)
    def circuit(params: jnp.ndarray):
        """Variational circuit whose expectation value is differentiated.

        Args:
            params (jnp.ndarray): Parameters passed to the ansatz.

        Returns:
            float: Expectation value of PauliZ on wire 0.
        """
        variational_ansatz(params, wires=range(N))
        return qml.expval(qml.PauliZ(0))

    grad_x = jax.grad(circuit)
    return jax.jit(grad_x) if jit else grad_x
# Number of timed gradient evaluations per configuration.
num_repeat = 100
# Shape of the ansatz parameters; it is the same for every configuration.
param_shape = variational_ansatz.shape(n_layers=n_layers, n_wires=N)

for shots in [None, 500]:
    for jit in [True, False]:
        grad_x = get_grad_fn(shots=shots, jit=jit)

        # One key for the warm-up call plus one per timed call. Reusing a
        # single key would generate the exact same parameters every time.
        warmup_key, *sample_keys = jax.random.split(
            jax.random.PRNGKey(42), num_repeat + 1
        )

        # Warm-up call so that one-off costs (e.g. compilation) are not timed.
        grad_x(jax.random.uniform(warmup_key, param_shape))

        # A plain list avoids relying on NumPy, which this script never imports.
        times = []
        for sample_key in sample_keys:
            x = jax.random.uniform(sample_key, param_shape)
            times.append(timeit(grad_x, x))

        mean_time = sum(times) / len(times)
        print(f"JIT {jit} | shots {shots} | time taken ", f"{mean_time:e}")