Trouble calling QJIT compiled QNode inside a nnx.jit compiled flax.nnx.Module

Hi,

I’m trying to “package” a Catalyst (qjit)-compiled circuit inside a flax.nnx.Module so that its weights are easy to manage.

However, the tracer passed by NNX is rejected by the qjit-compiled function.

Error:

```
TypeError: Value Traced<ShapedArray(float64[3])>with<DynamicJaxprTrace(level=1/0)> with type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type
```

Sample code (reproduces the error above):

```python
from functools import partial
import jax.random
import jax.numpy as jnp
from flax import nnx
import pennylane as qml
from pennylane import QNode, qjit

jax.config.update("jax_enable_x64", False)


def make_pqc_circuit(n_input: int) -> QNode:
    device = qml.device("lightning.qubit", wires=n_input)

    def circuit(inputs, weights):
        qml.AngleEmbedding(features=inputs, wires=range(n_input), rotation="X")
        for i in range(1, n_input):
            qml.CRX(weights[i - 1], wires=[i, 0])
        return qml.expval(qml.PauliZ(wires=0))

    return QNode(circuit, device, "jax")


class PQCModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs, n_input):
        self.circuit = qjit(make_pqc_circuit(n_input))
        self.weights = nnx.Param(jax.random.uniform(rngs.params(), shape=(n_input - 1,), minval=0, maxval=jnp.pi))

    def __call__(self, x):
        return self.circuit(x, self.weights)


@nnx.jit
@partial(nnx.vmap, in_axes=(None, 0), out_axes=0)
def forward_pass(module, x):
    return module(x)


if __name__ == '__main__':
    inputs = jax.random.uniform(jax.random.key(0), shape=(16, 3))
    module = PQCModule(nnx.Rngs(0), n_input=4)
    outputs = forward_pass(module, inputs)
    print(outputs.shape)
```

I couldn’t find much mention of Flax in the documentation. Is it not supported?
It would be amazing to have Flax module support.

Thanks.

Setup Details:
WSL2 Ubuntu on Windows 10

qml.about()

Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Author:
Author-email:
License: Apache License 2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning, PennyLane_Lightning_GPU

Platform info: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.9
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:

  • lightning.qubit (PennyLane_Lightning-0.39.0)
  • nvidia.custatevec (PennyLane-Catalyst-0.9.0)
  • nvidia.cutensornet (PennyLane-Catalyst-0.9.0)
  • oqc.cloud (PennyLane-Catalyst-0.9.0)
  • softwareq.qpp (PennyLane-Catalyst-0.9.0)
  • default.clifford (PennyLane-0.39.0)
  • default.gaussian (PennyLane-0.39.0)
  • default.mixed (PennyLane-0.39.0)
  • default.qubit (PennyLane-0.39.0)
  • default.qutrit (PennyLane-0.39.0)
  • default.qutrit.mixed (PennyLane-0.39.0)
  • default.tensor (PennyLane-0.39.0)
  • null.qubit (PennyLane-0.39.0)
  • reference.qubit (PennyLane-0.39.0)
  • lightning.gpu (PennyLane_Lightning_GPU-0.39.0)

pip show pennylane pennylane-catalyst jax flax

Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Author:
Author-email:
License: Apache License 2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning, PennyLane_Lightning_GPU

Name: PennyLane-Catalyst
Version: 0.9.0
Summary: A JIT compiler for hybrid quantum programs in PennyLane
Author:
Author-email:
License: Apache License 2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: diastatic-malt, jax, jaxlib, numpy, pennylane, pennylane-lightning, scipy
Required-by:

Name: jax
Version: 0.4.28
Summary: Differentiate, compile, and transform Numpy code.
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, flax, optax, orbax-checkpoint, PennyLane-Catalyst

Name: flax
Version: 0.10.2
Summary: Flax: A neural network library for JAX designed for flexibility
Author:
Author-email: Flax team flax-dev@google.com
License:
Location: /home//.local/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by:

Hi @LasradoRohan, thanks for sharing your use case! I haven’t personally tried Flax, but since it’s a JAX-based library, I would expect it to be supported in principle. Allow me some time to look into your example :slight_smile:

Hi @LasradoRohan, is the `__call__` method supposed to read as follows instead?

```python
    def __call__(self, x):
        return self.circuit(x, self.weights.value)
```

It seems that JAX doesn’t accept the nnx.Param object as an argument to the Catalyst function we register with JAX, but if I pass the underlying JAX array instead (self.weights.value), it works.

I’m really not familiar with Flax, so this is just a guess, but the first example in their docs already uses this .value pattern.

Can confirm that extracting the JAX array from the nnx.Param works.
Thank you for the solution! Apologies for the oversight on my part.

Glad we could get your program working :slight_smile: