I’m trying to wrap a Catalyst-compiled (`qjit`) circuit in a `flax.nnx.Module` so that its weights are easy to manage.
However, the tracer that NNX passes in is rejected by the `qjit`-compiled function:

```
TypeError: Value Traced<ShapedArray(float64[3])>with<DynamicJaxprTrace(level=1/0)> with type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type
```
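For context on where a tracer comes from: under a JAX transformation (NNX transforms such as `nnx.vmap` are lifted versions of JAX's), the wrapped function is called with abstract tracer objects rather than concrete arrays. The minimal sketch below records the argument type on each call. It uses a plain JAX function, `standin_circuit` (a hypothetical placeholder, not Catalyst), in place of the compiled circuit.

```python
import jax
import jax.numpy as jnp

# Collects the class name of whatever argument the function is called with.
seen_types = []


def standin_circuit(inputs):
    # Hypothetical stand-in for the compiled circuit: record the argument
    # type, then do ordinary array work on it.
    seen_types.append(type(inputs).__name__)
    return jnp.sum(jnp.cos(inputs))


batch = jnp.ones((16, 3))

standin_circuit(batch[0])           # eager call: receives a concrete array
jax.vmap(standin_circuit)(batch)    # batched call: receives a batching tracer
jax.jit(standin_circuit)(batch[0])  # staged call: receives a jaxpr tracer

print(seen_types)
```

The first entry is a concrete array class and the other two are tracer classes; the exact class names are JAX internals and may differ between versions.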
Sample code (reproduces above error):

```python
from functools import partial

import jax.random
import jax.numpy as jnp
from flax import nnx
import pennylane as qml
from pennylane import QNode, qjit

jax.config.update("jax_enable_x64", False)


def make_pqc_circuit(n_input: int) -> QNode:
    device = qml.device("lightning.qubit", wires=n_input)

    def circuit(inputs, weights):
        # Encode the inputs as X rotations, then entangle every wire with wire 0.
        qml.AngleEmbedding(features=inputs, wires=range(n_input), rotation="X")
        for i in range(1, n_input):
            qml.CRX(weights[i - 1], wires=[i, 0])
        return qml.expval(qml.PauliZ(wires=0))

    return QNode(circuit, device, "jax")


class PQCModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs, n_input):
        self.circuit = qjit(make_pqc_circuit(n_input))
        self.weights = nnx.Param(jax.random.uniform(rngs.params(), shape=(n_input - 1,), minval=0, maxval=jnp.pi))

    def __call__(self, x):
        return self.circuit(x, self.weights)


# Share the module across the batch (in_axes=None) and map over the inputs (axis 0).
@partial(nnx.vmap, in_axes=(None, 0), out_axes=0)
def forward_pass(module, x):
    return module(x)


if __name__ == '__main__':
    inputs = jax.random.uniform(jax.random.key(0), shape=(16, 3))
    module = PQCModule(nnx.Rngs(0), n_input=4)
    outputs = forward_pass(module, inputs)  # raises the TypeError above
```
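One pattern that may sidestep this (not verified with Catalyst) is to do the batching inside the compiled function and call it eagerly with concrete arrays, for example `self.weights.value` rather than wrapping the call in `nnx.vmap`. The sketch below shows only the shape of that idea in plain JAX: `standin_circuit` is a hypothetical placeholder for the QNode, and `jax.jit` stands in for `qjit`. Whether the Catalyst equivalent (e.g. `qjit` applied to a function batched with `catalyst.vmap`) behaves the same way, and how gradients would flow through it during training, are assumptions.

```python
import jax
import jax.numpy as jnp


def standin_circuit(inputs, weights):
    # Hypothetical single-sample stand-in for the QNode: one input vector and
    # the shared weights in, one scalar out.
    return jnp.cos(jnp.sum(inputs)) * jnp.prod(jnp.cos(weights))


# Map over the inputs only (in_axes=(0, None)) so the weights stay unbatched,
# then compile the batched function as a whole (jax.jit standing in for qjit).
batched_circuit = jax.jit(jax.vmap(standin_circuit, in_axes=(0, None)))

# A plain array playing the role of `module.weights.value`.
weights = jax.random.uniform(jax.random.key(1), shape=(3,), minval=0, maxval=jnp.pi)
inputs = jax.random.uniform(jax.random.key(0), shape=(16, 3))

# Called eagerly with concrete arrays: no outer transformation is involved.
outputs = batched_circuit(inputs, weights)
print(outputs.shape)
```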
I couldn’t find much mention of Flax in the documentation. Is Flax not supported?
Support for Flax modules would be amazing.
Setup Details:
WSL2 Ubuntu on Windows 10
Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
License: Apache License 2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning, PennyLane_Lightning_GPU
Platform info: Linux-
Python version: 3.11.9
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.39.0)
- nvidia.custatevec (PennyLane-Catalyst-0.9.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.9.0)
- oqc.cloud (PennyLane-Catalyst-0.9.0)
- softwareq.qpp (PennyLane-Catalyst-0.9.0)
- default.clifford (PennyLane-0.39.0)
- default.gaussian (PennyLane-0.39.0)
- default.mixed (PennyLane-0.39.0)
- default.qubit (PennyLane-0.39.0)
- default.qutrit (PennyLane-0.39.0)
- default.qutrit.mixed (PennyLane-0.39.0)
- default.tensor (PennyLane-0.39.0)
- null.qubit (PennyLane-0.39.0)
- reference.qubit (PennyLane-0.39.0)
- lightning.gpu (PennyLane_Lightning_GPU-0.39.0)
Output of `pip show pennylane pennylane-catalyst jax flax`:
Name: PennyLane-Catalyst
Version: 0.9.0
Summary: A JIT compiler for hybrid quantum programs in PennyLane
License: Apache License 2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: diastatic-malt, jax, jaxlib, numpy, pennylane, pennylane-lightning, scipy
Name: jax
Version: 0.4.28
Summary: Differentiate, compile, and transform Numpy code.
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home//.local/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, flax, optax, orbax-checkpoint, PennyLane-Catalyst
Name: flax
Version: 0.10.2
Summary: Flax: A neural network library for JAX designed for flexibility
Author-email: Flax team flax-dev@google.com
Location: /home//.local/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions