Dear experts,
I’m trying to set up a circuit with AmplitudeEmbedding + StronglyEntanglingLayers using JAX. Since I have 4 input features and amplitude embedding stores 2^n values on n qubits, 2 qubits are enough. I set it up as follows:
# --- Circuit configuration ---------------------------------------------------
n_qubits = 2  # 2 qubits -> 2**2 = 4 amplitudes for AmplitudeEmbedding
l = 1  # number of StronglyEntanglingLayers
# StronglyEntanglingLayers weight tensor: (layers, wires, 3 rotation angles)
w_shape = (l, n_qubits, 3)
# Length of the flat parameter vector the optimizer works with.
w_len = w_shape[0] * w_shape[1] * w_shape[2]
dev = qml.device("default.qubit.jax", wires=n_qubits)
# jit wraps the QNode, so every argument (including ``data``) reaches the
# circuit body as a JAX tracer rather than a concrete array.
@jax.jit
@qml.qnode(dev,interface="jax")
def circuit(data,weights):
    """Return <Z_0> after amplitude-embedding ``data`` and applying StronglyEntanglingLayers.

    Args:
        data: feature vector loaded as state amplitudes on ``n_qubits`` wires
            (presumably length 2**n_qubits, i.e. 4 here -- TODO confirm).
        weights: flat parameter vector; only its first ``w_len`` entries are
            used, reshaped to ``w_shape``.

    Returns:
        Expectation value of PauliZ on wire 0.
    """
    # Unflatten the optimizer's 1-D parameter vector into the w_shape tensor.
    se_weights = weights[:w_len].reshape(w_shape)
    # NOTE(review): this is the call that fails in the traceback. PennyLane's
    # AmplitudeEmbedding preprocessing compares the feature norm to 1 with
    # qml.math.allclose, which dispatches to np.allclose. np.allclose then
    # tries to convert the JAX tracer (from jit/vmap/pmap) into a NumPy array,
    # raising TracerArrayConversionError.
    # The traceback frame shows normalize=False while this line passes
    # normalize=True. The traceback does not show whether normalize=True
    # skips that check -- confirm against the installed PennyLane version.
    # TODO(review): check whether a newer PennyLane release avoids this
    # NumPy conversion for traced inputs.
    AmplitudeEmbedding(data,normalize=True,wires=range(n_qubits))
    StronglyEntanglingLayers(se_weights,wires=range(n_qubits))
    return qml.expval(qml.PauliZ(0))
# Number of local accelerator devices available for pmap.
n_cores = jax.local_device_count()
# Batch over events (axis 0 of the data); the weights are shared (None).
vcircuit = jax.vmap(circuit, in_axes=(0, None))
# Parallelise the batched circuit across devices, again sharing the weights.
pcircuit = jax.pmap(vcircuit, in_axes=(0, None))
def train(n_events):
    """Fit the circuit weights on ``n_events`` muon events with scipy.optimize.minimize.

    The dataset is split into ``n_cores`` equally sized chunks so that
    ``pcircuit`` (pmap over devices of a vmap over events) can evaluate it in
    parallel. The loss is the mean squared error between circuit outputs and
    labels. Its gradient comes from ``jax.grad``. Loss and sign-accuracy are
    printed after every optimizer iteration.

    Args:
        n_events: number of events requested from ``get_muon_dataset``.

    Returns:
        The ``OptimizeResult`` returned by ``minimize``.

    Raises:
        ValueError: if fewer events than devices are available.
    """
    datas, labels = get_muon_dataset(n_jets=n_events, balanced=True)

    # jnp.stack requires equally sized chunks. np.array_split gives unequal
    # ones when len(datas) is not a multiple of n_cores, so drop the remainder.
    n_keep = (len(datas) // n_cores) * n_cores
    if n_keep == 0:
        raise ValueError(
            f"need at least {n_cores} events to split across devices, got {len(datas)}"
        )
    datas = jnp.stack(np.array_split(datas[:n_keep], n_cores))
    labels = jnp.stack(np.array_split(labels[:n_keep], n_cores))

    def loss(weights):
        # Mean squared error over all events on all devices.
        # Named ``mse`` so it does not shadow the module-level layer count ``l``.
        out = pcircuit(datas, weights)
        mse = jnp.mean((out - labels) ** 2)
        return mse

    def accuracy(weights):
        # Fraction of events whose predicted sign matches the label's sign.
        out = pcircuit(datas, weights)
        return jnp.mean(jnp.sign(out) == jnp.sign(labels))

    def callbacks(weights):
        # Progress report; minimize calls this once per iteration with the
        # current parameter vector.
        mse = loss(weights)
        acc = accuracy(weights)
        print("loss: ", mse, " acc: ", acc)

    grad_loss = jax.grad(loss)
    weights = np.random.randn(w_len,)
    # Previously ``callbacks`` was defined but never passed, so nothing was
    # reported, and the result was discarded although callers use it.
    res = minimize(
        loss,
        weights,
        jac=grad_loss,
        callback=callbacks,
        options={'maxiter': 100},
    )
    return res
Unfortunately, I get the following error:
Traceback (most recent call last):
File "amplitudeencoding_jax.py", line 100, in <module>
res = train(1000)
File "amplitudeencoding_jax.py", line 84, in train
res = minimize(loss,weights,jac=grad_loss, options = {'maxiter':100})
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/_minimize.py", line 614, in minimize
return _minimize_bfgs(fun, x0, args, jac, callback, **options)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/optimize.py", line 1135, in _minimize_bfgs
sf = _prepare_scalar_function(fun, x0, jac, args=args, epsilon=eps,
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/optimize.py", line 261, in _prepare_scalar_function
sf = ScalarFunction(fun, x0, args, grad, hess,
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 136, in __init__
self._update_fun()
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 226, in _update_fun
self._update_fun_impl()
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 133, in update_fun
self.f = fun_wrapped(self.x)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/_differentiable_functions.py", line 130, in fun_wrapped
return fun(x, *args)
File "amplitudeencoding_jax.py", line 64, in loss
out = pcircuit(datas,weights)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/tape/qnode.py", line 530, in __call__
self.construct(args, kwargs)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/tape/qnode.py", line 469, in construct
self.qfunc_output = self.func(*args, **kwargs)
File "amplitudeencoding_jax.py", line 44, in circuit
AmplitudeEmbedding(data,normalize=False,wires=range(n_qubits))
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/decorator.py", line 69, in wrapper
func(*args, **kwargs)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/embeddings/amplitude.py", line 296, in AmplitudeEmbedding
features = _preprocess(features, wires, pad_with, normalize)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/embeddings/amplitude.py", line 81, in _preprocess
if not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/math/fn.py", line 138, in allclose
return np.allclose(t1, t2, rtol=rtol, atol=atol, **kwargs)
File "<__array_function__ internals>", line 5, in allclose
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/numpy/core/numeric.py", line 2189, in allclose
res = all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
File "<__array_function__ internals>", line 5, in isclose
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/numpy/core/numeric.py", line 2278, in isclose
x = asanyarray(a)
File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/numpy/core/_asarray.py", line 136, in asanyarray
return array(a, dtype, copy=False, order=order, subok=True)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[])>with<BatchTrace(level=1/2)> with
val = Traced<ShapedArray(float64[100])>with<DynamicJaxprTrace(level=0/2)>
batch_dim = 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Does anybody know whether AmplitudeEmbedding can be used with JAX, in particular inside jax.jit / jax.vmap / jax.pmap? If so, how can I solve this issue?
Thank you in advance!
Cheers,
Davide