Problem when using JAX with AmplitudeEmbedding

Dear experts,

I’m trying to set up a circuit with AmplitudeEmbedding + StronglyEntanglingLayers using JAX. I’m using 4 input variables, therefore I’m playing with 2 qubits in the following way:


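import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
from scipy.optimize import minimize
from dataset_utils import get_muon_dataset

n_qubits = 2  # 4 input variables -> 2 qubits
dev = qml.device("default.qubit.jax", wires=n_qubits)

l = 1
w_shape = (l, n_qubits, 3)
w_len = l * n_qubits * 3  # total number of trainable parameters

@qml.qnode(dev, interface="jax")
def circuit(data, weights):
    se_weights = weights[:w_len].reshape(w_shape)
    qml.templates.AmplitudeEmbedding(data, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(se_weights, wires=range(n_qubits))
    return qml.expval(qml.PauliZ(0))

# vectorize over the events on each core, then parallelize across cores
vcircuit = jax.vmap(circuit, in_axes=(0, None))
n_cores = jax.local_device_count()
pcircuit = jax.pmap(vcircuit, in_axes=(0, None))

def train(n_events):

    datas, labels = get_muon_dataset(n_jets=n_events, balanced=True)
    # split the events into one chunk per core
    datas = jnp.stack(np.array_split(datas, n_cores))
    labels = jnp.stack(np.array_split(labels, n_cores))

    def loss(weights):
        out = pcircuit(datas, weights)
        return jnp.mean((out - labels) ** 2)

    def accuracy(weights):
        out = pcircuit(datas, weights)
        return jnp.mean(jnp.sign(out) == jnp.sign(labels))

    def callbacks(weights):
        l = loss(weights)
        a = accuracy(weights)
        print("loss: ", l, "  acc: ", a)

    grad_loss = jax.grad(loss)
    weights = np.random.randn(w_len)
    res = minimize(loss, weights, jac=grad_loss, options={'maxiter': 100})
    return res

res = train(1000)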

But unfortunately, I get this error:

  Traceback (most recent call last):
  File "", line 100, in <module>
  res = train(1000)
  File "", line 84, in train
  res = minimize(loss,weights,jac=grad_loss, options = {'maxiter':100})
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 614, in minimize
  return _minimize_bfgs(fun, x0, args, jac, callback, **options)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 1135, in _minimize_bfgs
  sf = _prepare_scalar_function(fun, x0, jac, args=args, epsilon=eps,
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 261, in _prepare_scalar_function
  sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 136, in __init__
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 226, in _update_fun
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 133, in update_fun
  self.f = fun_wrapped(self.x)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/scipy/optimize/", line 130, in fun_wrapped
  return fun(x, *args)
  File "", line 64, in loss
  out = pcircuit(datas,weights)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/tape/", line 530, in __call__
  self.construct(args, kwargs)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/tape/", line 469, in construct
  self.qfunc_output = self.func(*args, **kwargs)
  File "", line 44, in circuit
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/", line 69, in wrapper
  func(*args, **kwargs)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/embeddings/", line 296, in AmplitudeEmbedding
  features = _preprocess(features, wires, pad_with, normalize)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/embeddings/", line 81, in _preprocess
  if not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/math/", line 138, in allclose
  return np.allclose(t1, t2, rtol=rtol, atol=atol, **kwargs)
  File "<__array_function__ internals>", line 5, in allclose
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/numpy/core/", line 2189, in allclose
  res = all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
  File "<__array_function__ internals>", line 5, in isclose
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/numpy/core/", line 2278, in isclose
  x = asanyarray(a)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/numpy/core/", line 136, in asanyarray
  return array(a, dtype, copy=False, order=order, subok=True)
  jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[])>with<BatchTrace(level=1/2)> with
    val = Traced<ShapedArray(float64[100])>with<DynamicJaxprTrace(level=0/2)>
    batch_dim = 0

Does anybody know if it is possible to implement AmplitudeEmbedding in JAX? And if so, what can I do to solve this issue?

Thank you in advance!


Hi @Davide_Zuliani! It appears that the internal check to ensure that the amplitude is normalized,

qml.math.allclose(norm, 1.0, atol=TOLERANCE)

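does not support traced values, such as those produced by @jax.jit, jax.vmap, or jax.pmap: it tries to convert the norm to a NumPy array, which JAX cannot do for a tracer. Unfortunately, this check is currently hardcoded, so the only way around it with the current PL version is to modify the AmplitudeEmbedding template code to remove the check, like so: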

diff --git a/pennylane/templates/embeddings/ b/pennylane/templates/embeddings/
index e0200cf8..7a0682cb 100644
--- a/pennylane/templates/embeddings/
+++ b/pennylane/templates/embeddings/
@@ -188,16 +188,9 @@ class AmplitudeEmbedding(Operation):
                 feature_set = qml.math.concatenate([feature_set, padding], axis=0)

             # normalize
-            norm = qml.math.sum(qml.math.abs(feature_set) ** 2)
-            if not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
-                if normalize or pad_with:
-                    feature_set = feature_set / np.sqrt(norm)
-                else:
-                    raise ValueError(
-                        f"Features must be a vector of norm 1.0; got norm {norm}."
-                        "Use 'normalize=True' to automatically normalize."
-                    )
+            if normalize:
+                norm = qml.math.sum(qml.math.abs(feature_set) ** 2)
+                feature_set = feature_set / qml.math.sqrt(norm)

             features_batch[i] = qml.math.cast(feature_set, np.complex128)

Once you do this, the JIT will work:

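import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(data, weights):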
    qml.templates.AmplitudeEmbedding(data, normalize=True, wires=[0, 1])
    qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

data = jnp.ones([4], dtype=jnp.float32)
weights = jnp.ones([2, 2, 3], dtype=jnp.float32)
>>> circuit(data, weights)
DeviceArray(0.13645427, dtype=float32)
>>> jax.grad(circuit)(data, weights)
DeviceArray([-0.11450433, -0.45287138,  0.24969524,  0.31768045], dtype=float32)

This is definitely feedback we will take back and consider. However, we want to find a balance between

  1. Ensuring all PennyLane code supports the JIT, and
  2. Providing useful validation checks!

So we will investigate and see if it is possible to retain the validation on state normalization, while also supporting JIT workflows.
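To make the trade-off concrete, here is a minimal sketch in plain JAX (not the actual PennyLane implementation). It first reproduces the failure of a host-side np.allclose check under jax.jit, then shows one possible compromise: validating only when the norm is a concrete value and skipping the check when it is traced. The names strict_preprocess, tracer_aware_preprocess, and is_abstract, as well as the TOLERANCE value, are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.core import Tracer
from jax.errors import TracerArrayConversionError

TOLERANCE = 1e-10  # hypothetical tolerance, chosen only for this sketch


def is_abstract(x):
    """Return True if x is a JAX tracer, i.e. it has no concrete value yet."""
    return isinstance(x, Tracer)


def strict_preprocess(features):
    """Mimics a hardcoded check: always compares the norm with NumPy."""
    norm = jnp.sum(jnp.abs(features) ** 2)
    if not np.allclose(norm, 1.0, atol=TOLERANCE):
        raise ValueError(f"Features must be a vector of norm 1.0; got norm {norm}.")
    return features


def tracer_aware_preprocess(features, normalize=False):
    """Validates the norm only when it is concrete; skips the check when traced."""
    norm = jnp.sum(jnp.abs(features) ** 2)
    if normalize:
        # Pure JAX arithmetic, so this branch is safe under jit/vmap/pmap.
        return features / jnp.sqrt(norm)
    if not is_abstract(norm) and not np.allclose(norm, 1.0, atol=TOLERANCE):
        raise ValueError(f"Features must be a vector of norm 1.0; got norm {norm}.")
    return features


features = jnp.full((4,), 0.5)  # already normalized: 4 * 0.25 = 1

# The strict version fails as soon as the input is traced.
try:
    jax.jit(strict_preprocess)(features)
    jit_failed = False
except TracerArrayConversionError:
    jit_failed = True

# The tracer-aware version works under jit, with or without normalization.
normalized = jax.jit(lambda f: tracer_aware_preprocess(f, normalize=True))(jnp.ones(4))
unchecked = jax.jit(tracer_aware_preprocess)(features)
```

The cost of this design is that un-normalized inputs pass silently inside traced code, so validation only protects eager calls.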


@Davide_Zuliani a small update: I have created a PR to enable jax.jit support with AmplitudeEmbedding here:

Before this gets merged in, you could install PennyLane directly from this branch to test-drive it :slight_smile:

pip install git+

Hi @josh!
Thank you very much for your quick answer, very much appreciated!

Indeed, AmplitudeEmbedding is now working with JAX; I tested it with some code similar to your example.
Unfortunately, I still have some trouble when trying to parallelize the computation over several CPU cores.

Here I post my code (an easy example to run on 5 CPU cores):

import os
import re
import time
import argparse
def set_host_device_count(n):
    xla_flags = os.getenv('XLA_FLAGS', '').lstrip('--')
    xla_flags = re.sub(r'xla_force_host_platform_device_count=.+\s', '', xla_flags).split()
    os.environ['XLA_FLAGS'] = ' '.join(['--xla_force_host_platform_device_count={}'.format(n)]
                                   + xla_flags)
import pennylane as qml
from pennylane import wires
from pennylane.templates import AngleEmbedding, StronglyEntanglingLayers, QAOAEmbedding, AmplitudeEmbedding

from jax.config import config
from dataset_utils import get_muon_dataset
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from scipy.optimize import minimize
import matplotlib.pyplot as plt

from numpy.random import randn


n_qubits = 2  # 4 input features -> 2 qubits
dev = qml.device("default.qubit", wires=n_qubits)

l = 1

@qml.qnode(dev, interface="jax")
def circuit(data, weights):
    qml.templates.AmplitudeEmbedding(data, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return qml.expval(qml.PauliZ(0))

vcircuit = jax.vmap(circuit,in_axes=(0,None))
n_cores = jax.local_device_count()
pcircuit = jax.pmap(vcircuit,in_axes=(0,None))

datas = jnp.stack([data] * n_cores)
datas = jnp.stack(np.array_split(datas,n_cores))

weights = jnp.ones([l,n_qubits,3], dtype=jnp.float32)

out = pcircuit(datas,weights)

The error that pops out is something like this:

Traceback (most recent call last):
  File "", line 75, in <module>
    out = pcircuit(datas,weights)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/", line 674, in __call__
    self.construct(args, kwargs)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/", line 582, in construct
    self.qfunc_output = self.func(*args, **kwargs)
  File "", line 41, in circuit
    qml.templates.AmplitudeEmbedding(data, wires=range(n_qubits))
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/embeddings/", line 132, in __init__
    features = self._preprocess(features, wires, pad_with, normalize)
  File "/lhcbdata/miniconda3/envs/dzuliani_env/lib/python3.8/site-packages/pennylane/templates/embeddings/", line 197, in _preprocess
    elif not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/2)> with
  val = Traced<ShapedArray(bool[1])>with<DynamicJaxprTrace(level=0/2)>
  batch_dim = 0
The problem arose with the bool function.

With AngleEmbedding everything works fine.

Thank you very much for your help :slight_smile:
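For reference, the data layout that jax.pmap over jax.vmap expects can be sketched in plain JAX, independently of PennyLane. In the sketch below, model is a hypothetical stand-in for the QNode (it only normalizes a feature vector and returns a scalar), and per_device and n_features are illustrative values, not taken from the script above. The code adapts to however many local devices JAX reports.

```python
import jax
import jax.numpy as jnp
import numpy as np


def model(data, weights):
    """Hypothetical stand-in for the QNode: one feature vector in, one scalar out."""
    state = data / jnp.sqrt(jnp.sum(jnp.abs(data) ** 2))  # amplitude-style normalization
    return jnp.dot(state ** 2, weights)


# vmap batches over the events on one device; pmap spreads the batches over devices.
# in_axes=(0, None) maps over the leading data axis and broadcasts the weights.
vmodel = jax.vmap(model, in_axes=(0, None))
pmodel = jax.pmap(vmodel, in_axes=(0, None))

n_devices = jax.local_device_count()
per_device = 3   # events handled by each device (illustrative)
n_features = 4   # 4 features -> 2 qubits in the thread's setting

# Flat dataset of shape (n_devices * per_device, n_features).
flat = np.arange(1, n_devices * per_device * n_features + 1, dtype=np.float32)
flat = flat.reshape(-1, n_features)

# pmap requires the leading axis to equal the number of devices:
# (n_devices, per_device, n_features).
datas = jnp.stack(np.array_split(flat, n_devices))
weights = jnp.ones(n_features, dtype=jnp.float32)

out = pmodel(datas, weights)  # one scalar per event, shape (n_devices, per_device)
```

Because the normalization here is pure JAX arithmetic with no Python-level branching on the norm, it traces cleanly under both vmap and pmap.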

Oh, nice catch @Davide_Zuliani! Just to clarify, this is when you run the code snippet on the abstract-support branch of PennyLane?

I think you may have caught an edge case; I will see if we can get this working :slight_smile:

A quick update – I believe I have fixed this issue in the branch now. You may need to run

pip uninstall pennylane
pip install git+

for this change to take effect, if you have installed this branch using pip (otherwise, if you used git, simply pull the latest changes!).

Let me know if this now works for you :slight_smile:

Hi @josh!
Thanks for the tips!

I actually managed to make it work by making the following changes:

In file ../pennylane/templates/embeddings/ I changed line 197 from
elif not qml.math.allclose(norm, 1.0, atol=TOLERANCE):
to
elif not qml.math.is_abstract(norm):

while in file ../pennylane/devices/ I’ve commented out lines 636-638:

#if not qml.math.is_abstract(state):
#    if not qml.math.allclose(qml.math.linalg.norm(state, ord=2), 1.0, atol=tolerance):
#        raise ValueError("Sum of amplitudes-squared does not equal one.")

I think these are rather crude changes (I may not have understood the real issue, although qml.math.allclose does seem to be what is giving me problems), but at least it’s working fine now.

I will also try your tips and see if they work; your approach definitely seems more “orthodox”.

Again thank you very much for your help :slight_smile:

Hi @Davide_Zuliani, thank you for sharing your solution too!

Enjoy using PennyLane :smile: