Thanks for pointing me to this JAX+Optax tutorial. I have modified my code to use both.
It works with dev = qml.device('default.qubit', wires=n_qubits)
but crashes with dev = qml.device('default.qubit', wires=n_qubits, shots=1000)
The error is `ValueError: probabilities do not sum to 1`. Below is the full dump:
M: verify code sanity, X: (300, 2)
Traceback (most recent call last):
File "/PennyLane/toys/./toy_opt_speed_jax_optax.py", line 62, in <module>
val=loss_fn(params, X,Y)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 143, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2727, in bind
return self.bind_with_trace(top_trace, args, params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 423, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1409, in _pjit_call_impl
return xc._xla.pjit(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1392, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1348, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1201, in __call__
results = self.xla_executable.execute_sharded(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: probabilities do not sum to 1
At:
numpy/random/_generator.pyx(828): numpy.random._generator.Generator.choice
/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(443): <listcomp>
/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(443): sample_state
/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(297): _measure_with_samples_diagonalizing_gates
/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py(198): measure_with_samples
/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/simulate.py(200): measure_final_state
/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/simulate.py(263): simulate
/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py(554): <genexpr>
Attached is the reproducer:
import numpy as cnp
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
from time import time

n_sampl = 300
n_feature = 2; n_qubits = 3; layers = 1; steps = 10

# FIX for the reported crash: finite-shot sampling inside jax.jit needs a
# JAX-compatible PRNG key. Without `seed=jax.random.PRNGKey(...)` the device
# samples with a NumPy Generator, which receives traced/abstract probabilities
# under jit and raises "ValueError: probabilities do not sum to 1"
# (see numpy.random._generator.Generator.choice in the traceback).
#dev = qml.device('default.qubit', wires=n_qubits)  # analytic (exact) mode — works
dev = qml.device('default.qubit', wires=n_qubits, shots=1000,
                 seed=jax.random.PRNGKey(0))  # shot-based, jit-compatible

#.... input data
Xu = cnp.random.uniform(-1, 1, size=(n_sampl, n_feature))
Xa = cnp.arccos(Xu)  # map uniform features into [0, pi] rotation angles
X = jnp.array(Xa)
Y = jnp.where(Xu[:, 0] * Xu[:, 1] > 0, 1, -1)  # labels: sign of the feature product

#... trainable params, shape (layers, n_qubits, 3) — one (RX, RY, RZ) triple per qubit
params = 0.2 * jnp.array(cnp.random.random(size=(layers, n_qubits, 3)))
@qml.qnode(dev)
def circuit(params, x):
    """Variational classifier circuit.

    Angle-encodes two features via RY on wires 0 and 1, applies an
    EfficientSU2-style ansatz (per-qubit RX/RY/RZ followed by a ring of
    CNOTs, repeated `layers` times), and returns the expectation of
    PauliZ on wire 2 as the prediction in [-1, 1].

    Args:
        params: array of shape (layers, n_qubits, 3) — rotation angles.
        x: feature vector (or batch column) of length 2.
    """
    qml.RY(x[0], wires=0)
    qml.RY(x[1], wires=1)
    for layer in range(layers):  # EfficientSU2 ansatz
        qml.Barrier()
        for qubit in range(n_qubits):
            qml.RX(params[layer, qubit, 0], wires=qubit)
            qml.RY(params[layer, qubit, 1], wires=qubit)
            qml.RZ(params[layer, qubit, 2], wires=qubit)
        # NOTE(review): entangling ring assumed to be per-layer (EfficientSU2
        # convention) — the paste lost indentation; with layers=1 it is
        # equivalent either way. Confirm against the original file.
        for qubit in range(n_qubits):
            qml.CNOT(wires=[qubit, (qubit + 1) % n_qubits])
    return qml.expval(qml.PauliZ(2))

print(qml.draw(circuit, decimals=2)(params, X[0]), '\n')
#... classical ML utility func
@jax.jit
def loss_fn(params, X, Y):
    """Mean-squared-error loss between labels Y and batched circuit predictions.

    FIX: the original `cost = jnp.mean((Y - pred)) ** 2)` had unbalanced
    parentheses (SyntaxError) and, read either way, squared the mean of the
    residuals instead of averaging the squared residuals. MSE squares
    element-wise *inside* the mean.

    Args:
        params: trainable angles, shape (layers, n_qubits, 3).
        X: feature batch, shape (n_sampl, n_feature); transposed so the
           circuit receives one feature row per wire, vectorized over samples.
        Y: labels in {-1, +1}, shape (n_sampl,).

    Returns:
        Scalar MSE loss.
    """
    pred = circuit(params, X.T)  # vectorized execution over the batch
    cost = jnp.mean((Y - pred) ** 2)
    return cost
# Sanity check: one forward pass of the jitted loss before training starts.
print('M: verify code sanity, X:',X.shape)
T0=time()
val=loss_fn(params, X,Y) # <=== CRASH IS HERE
durT=time()-T0
# First call includes jit compilation time, so elaT overstates steady-state cost.
print('elaT=%.1f sec, one loss_fn:%s'%(durT,val))
# One gradient evaluation to confirm differentiation works on this device.
print('grad:',jax.grad(loss_fn)(params,X,Y))
#... run optimizer
opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)
def update_step(opt, params, opt_state, data, targets):
    """Run one training iteration: evaluate loss and gradient, apply an Optax update.

    Returns the updated parameters, the new optimizer state, and the loss value.
    """
    loss_value, grad_tree = jax.value_and_grad(loss_fn)(params, data, targets)
    delta, opt_state = opt.update(grad_tree, opt_state)
    new_params = optax.apply_updates(params, delta)
    return new_params, opt_state, loss_value
# Main optimization loop: 100 Adam steps, logging the loss every 5th step.
for i in range(100):
    params, opt_state, loss_val = update_step(opt, params, opt_state, X, Y)
    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")
How should I proceed now to run the optimization on a shot-based device?