Error faced in training the quantum network for estimating parameters

Hi @G_Akash, it seems your workflow evaluates a lot of circuits with a relatively low qubit count. This is one of the regimes where Lightning underperforms the most, because creating every circuit in PennyLane incurs some overhead for decompositions, validation, etc. That overhead will become relatively cheap at some point (albeit still expensive in absolute terms), but that may only happen in the 15-20 qubit range, depending on several factors. To deal with that, the PennyLane team has created Catalyst, which allows jitting circuits to avoid these overheads. I could not make it work off the bat with your scripts, but you may want to try it in the future.
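
To picture why the overhead only becomes "relatively cheap" at larger qubit counts, here is a toy cost model. It is purely an assumption for illustration, not a measurement: each circuit pays a fixed overhead, and the simulation cost grows with the 2**n amplitudes of the state vector. The function name and the unitless constants are hypothetical.

```python
def overhead_fraction(n_qubits, fixed_overhead, cost_per_amplitude):
    """Share of the total runtime spent on the fixed per-circuit overhead.

    Toy model (an assumption):
        total = fixed_overhead + cost_per_amplitude * 2**n_qubits
    """
    simulation_cost = cost_per_amplitude * 2**n_qubits
    return fixed_overhead / (fixed_overhead + simulation_cost)


# Unitless, made-up constants chosen so that both costs are equal at 10 qubits.
for n in (3, 10, 20):
    print(n, overhead_fraction(n, fixed_overhead=1024, cost_per_amplitude=1))
```

In this model the overhead dominates for a handful of qubits and fades as the state vector grows. Where the crossover actually sits depends on the real constants, which this sketch does not try to estimate.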


Hi @G_Akash, I reached out to the Catalyst team and they provided a fix (the default “pass” was missing "func.func(hlo-legalize-shapeops-to-standard)") and an example based on one of your workflows (see below). This is running PennyLane and Lightning v0.34 with Catalyst v0.4.1 on my laptop. Here are some figures:

Time to apply qnn (lightning) = 5.686362e-03
Time to compile/apply qnn = 1.556851e+00
Time to apply qnn = 3.990927e-04

So if you call a function enough times, as in an optimization problem, it is worth paying the compilation price to get a neat speed-up of 10x+ over using Lightning alone. I encourage you to give it a try (normally you should not need the pipelines argument below and can simply define qjit_circ = qjit(circuit)) and to report back any problems you might face. Cheers!

import pennylane as qml
from pennylane import numpy as np
from catalyst import qjit
from catalyst.compiler import DEFAULT_PIPELINES

n_dset = 10
n_qubits = 3
layers = 2

CustomHLOLoweringPass = (
    "CustomHLOLoweringPass",
    [
        "canonicalize",
        "func.func(chlo-legalize-to-hlo)",
        "stablehlo-legalize-to-hlo",
        "func.func(mhlo-legalize-control-flow)",
        "func.func(hlo-legalize-shapeops-to-standard)",
        "func.func(hlo-legalize-to-linalg)",
        "func.func(mhlo-legalize-to-std)",
        "convert-to-signless",
        "canonicalize",
        "scatter-lowering",
        "hlo-custom-call-lowering",
        "cse",
    ],
)

DEFAULT_PIPELINES[0] = CustomHLOLoweringPass

dev = qml.device("lightning.qubit", wires=n_qubits)

@qml.qnode(dev, diff_method="adjoint")
def qnn(weights, inputs):
    qml.AmplitudeEmbedding(inputs, wires=range(n_qubits), pad_with=0.5)
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    for i in range(n_qubits - 1):
        if i % 2 == 0:
            qml.CNOT(wires=[i, i + 1])
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)[1::2]]

x_train = np.random.rand(n_dset, 2**n_qubits)
weights = np.random.rand(layers, n_qubits)

import time

t0 = time.time()
n_iter = 0
while time.time() - t0 < 1.0:
    n_iter += 1
    qnn(weights, x_train[0])
t1 = time.time() - t0
print(f"Time to apply qnn (lightning) = {t1 / n_iter:0.6e}")

qjit_qnn = qjit(qnn, pipelines=DEFAULT_PIPELINES)

t0 = time.time()
warmup = qjit_qnn(weights, x_train[0])
t1 = time.time() - t0
print(f"Time to compile/apply qnn = {t1:0.6e}")

t0 = time.time()
n_iter = 0
while time.time() - t0 < 1.0:
    n_iter += 1
    qjit_qnn(weights, x_train[0])
t1 = time.time() - t0
print(f"Time to apply qnn = {t1 / n_iter:0.6e}")
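
As a rough check on when compilation pays off, the three figures reported above can be plugged into a break-even estimate. This is a back-of-the-envelope sketch: the variable names are hypothetical, and it treats the whole first qjit call as compilation cost even though it also includes one execution.

```python
import math

# Figures reported above, in seconds.
t_lightning = 5.686362e-03  # per call, Lightning without qjit
t_compile = 1.556851e+00    # first qjit call: compilation plus one execution
t_qjit = 3.990927e-04       # per call, once compiled

# Per-call speed-up after compilation.
speedup = t_lightning / t_qjit

# Number of calls needed before the one-off compilation cost is recovered.
break_even_calls = math.ceil(t_compile / (t_lightning - t_qjit))

print(f"Per-call speed-up: {speedup:.1f}x")
print(f"Break-even after about {break_even_calls} calls")
```

With these numbers the compilation cost is recovered after a few hundred calls, which an optimization loop typically exceeds.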

Hey @Vincent_Michaud-Riou, good to hear from you.

I will check this out for my problem and will let you know. However, currently, I am using default.qubit with backprop and the JAX-JIT interface, and I consistently get a huge speed-up (about 5-7x, I suppose) for a 7-qubit, 25-layer structure where one layer consists of a qml.Rot, a qml.CNOT and finally a qml.U3 (so we have 1050 trainable weights in total). However, the caveat is that I have to spend some time on the compilation. I am not well-versed in jitting, so I assume the compilation has to do with the concept of jitting.

Could you let me know how qjit differs from JAX jit?
Also, if you could, I need some help in figuring out a circuit architecture for one of my problems. Do let me know if it is possible. Thank you

Hi @G_Akash ,

I have not played that much with either jax.jit or Catalyst’s qjit myself, but let me try to answer your questions as best as I can. When I speak of “jitting” I mean Just-In-Time compilation, so it has everything to do with compilation indeed. You could have Just-In-Time anything, but in the present context compilation is implied. Since you can actually trigger compilation ahead of time, it may be clearer to simply speak of compilation, but since jit is much in use I’ll use the terms interchangeably.

Could you let me know how qjit differs from JAX jit?

Simply put, jax.jit has a different scope than qjit, which stands for quantum JIT. JAX is a high-performance numerical library that supports automatic differentiation (e.g. backprop) and targets ML/AI applications, and jax.jit is its general-purpose JIT compiler. qjit compiles quantum circuits for fast execution on classical, quantum and hybrid backends. Since one rarely has access to quantum computers these days (reliable ones anyway), this usually means it will tie into one or several high-performance simulators, and that is where the Lightning plugins come in.

At a lower level, JAX will trace, analyze and compile the whole default.qubit workflow, heuristically mapping gate applications to optimized tensor contraction operations, etc., whereas Catalyst will analyze the quantum circuit (or workflow) you would like to run and directly call Lightning or some other specialized high-performance quantum simulator. JAX is good, so it will likely do well, as you have experienced, but it does not have the awareness of what you’re trying to do that Catalyst does. JAX also has to rely on generic linear algebra libraries to execute costly quantum gates, while Catalyst calls Lightning, which has special routines to do so (and which are on average faster). Finally, JAX will require a lot of memory and time to compile workflows, in particular to trace through everything for backprop, while Catalyst can use Lightning’s adjoint differentiation method, which does not need full tracing, consumes less memory (sometimes much less) and is faster (again on average, in the large-system limit, etc.).

Also, if you could, I need some help in figuring out a circuit architecture for one of my problems.

I would have to refer you to a colleague since this is outside my field of expertise. I would suggest creating another thread to ask this. But if you’re wondering how to run something larger and faster, then do reach out, this is my lawn :wink:

Hey @Vincent_Michaud-Riou, thanks for the explanation. I think I got the basic idea. To understand Catalyst and qjit, I tried running some toy circuits, but I got a weird error and I don’t understand where I went wrong.

@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0,1])
    return qml.expval(qml.PauliZ(wires=1))

jitted_circuit = qjit(circuit)
jitted_circuit(0.7)

This is the error I received. Kindly let me know what this error means. By the way, I am using PennyLane v0.35 and Catalyst v0.5.

----> 1 jitted_circuit(0.7)

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/jit.py:108, in QJIT.__call__(self, *args, **kwargs)
    105 if EvaluationContext.is_tracing():
    106     return self.user_function(*args, **kwargs)
--> 108 requires_promotion = self.jit_compile(args)
    110 # If we receive tracers as input, dispatch to the JAX integration.
    111 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/jit.py:169, in QJIT.jit_compile(self, args)
    167     self.jaxpr, self.out_treedef, self.c_sig = self.capture(args)
    168     self.mlir_module, self.mlir = self.generate_ir()
--> 169     self.compiled_function, self.qir = self.compile()
    171     self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)
    173 elif self.compiled_function is not cached_fn.compiled_fn:
    174     # Restore active state from cache.

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/jit.py:271, in QJIT.compile(self)
    267 # The function name out of MLIR has quotes around it, which we need to remove.
    268 # The MLIR function name is actually a derived type from string which has no
    269 # `replace` method, so we need to get a regular Python string out of it.
    270 func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
--> 271 shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
    273 shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
    274 compiled_fn = CompiledFunction(shared_object, func_name, restype, self.compile_options)

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/compiler.py:504, in Compiler.run(self, mlir_module, *args, **kwargs)
    489 def run(self, mlir_module, *args, **kwargs):
    490     """Compile an MLIR module to a shared object.
    491 
    492     .. note::
   (...)
    501         (str): filename of shared object
    502     """
--> 504     return self.run_from_ir(
    505         mlir_module.operation.get_asm(
    506             binary=False, print_generic_op_form=False, assume_verified=True
    507         ),
    508         str(mlir_module.operation.attributes["sym_name"]).replace('"', ""),
    509         *args,
    510         **kwargs,
    511     )

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/compiler.py:481, in Compiler.run_from_ir(self, ir, module_name, workspace)
    478 ret_type_name = compiler_output.get_function_attributes().get_return_type()
    480 if lower_to_llvm:
--> 481     output = LinkerDriver.run(filename, options=self.options)
    482     output_filename = str(pathlib.Path(output).absolute())
    483 else:

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/compiler.py:410, in LinkerDriver.run(infile, outfile, flags, fallback_compilers, options)
    408     options = CompileOptions()
    409 if flags is None:
--> 410     flags = LinkerDriver.get_default_flags(options)
    411 if fallback_compilers is None:
    412     fallback_compilers = LinkerDriver._default_fallback_compilers

File ~/anaconda3/envs/pycat/lib/python3.11/site-packages/catalyst/compiler.py:302, in LinkerDriver.get_default_flags(options)
    300 file_prefix = "libopenblas"
    301 search_pattern = path.join(scipy_lib_path, f"{file_prefix}*{file_extension}")
--> 302 openblas_so_file = glob.glob(search_pattern)[0]
    303 openblas_lib_name = path.basename(openblas_so_file)[3 : -len(file_extension)]
    305 lib_path_flags += [
    306     f"-Wl,-rpath,{scipy_lib_path}",
    307     f"-L{scipy_lib_path}",
    308 ]

IndexError: list index out of range

Edit: I tried with PennyLane v0.34 and Catalyst v0.4.1 and the error remained the same.
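
Reading the last frame of the traceback, the failing line indexes the result of glob.glob(search_pattern) with [0], so the IndexError means no libopenblas* file was found under scipy_lib_path. The sketch below reproduces that failure mode in isolation. The find_openblas helper and the temporary directory are hypothetical stand-ins, not Catalyst's code, and the ".so" extension is an assumption for Linux.

```python
import glob
import tempfile
from os import path


def find_openblas(lib_dir, file_extension=".so"):
    """Return the libopenblas files found in lib_dir (possibly an empty list)."""
    search_pattern = path.join(lib_dir, f"libopenblas*{file_extension}")
    return glob.glob(search_pattern)


# An empty temporary directory stands in for a library folder
# that does not contain an OpenBLAS shared object.
with tempfile.TemporaryDirectory() as empty_dir:
    matches = find_openblas(empty_dir)
    print(matches)
    try:
        matches[0]  # same indexing as the failing line in the traceback
    except IndexError as err:
        print(f"IndexError: {err}")
```

Pointing such a helper at the SciPy library directory of the active environment is one way to check whether the file the linker flags rely on is actually present.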

Hi @G_Akash , to be sure, could you run the script below (which works for me) and send me the output? This way we’ll make sure we are using the same Python version, etc.

import pennylane as qml
from catalyst import qjit

qml.about()

@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0,1])
    return qml.expval(qml.PauliZ(wires=1))

jitted_circuit = qjit(circuit)
jitted_circuit(0.7)

I get

qml.about
Name: PennyLane
Version: 0.35.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /tmp/venv/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           Linux-5.15.0-100-generic-x86_64-with-glibc2.31
Python version:          3.10.13
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.35.0)
- default.gaussian (PennyLane-0.35.0)
- default.mixed (PennyLane-0.35.0)
- default.qubit (PennyLane-0.35.0)
- default.qubit.autograd (PennyLane-0.35.0)
- default.qubit.jax (PennyLane-0.35.0)
- default.qubit.legacy (PennyLane-0.35.0)
- default.qubit.tf (PennyLane-0.35.0)
- default.qubit.torch (PennyLane-0.35.0)
- default.qutrit (PennyLane-0.35.0)
- null.qubit (PennyLane-0.35.0)
- lightning.qubit (PennyLane_Lightning-0.35.0)
- nvidia.custatevec (PennyLane-Catalyst-0.5.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.5.0)
- softwareq.qpp (PennyLane-Catalyst-0.5.0)

And while I’m at it, here is the result of pip freeze

pip freeze
appdirs==1.4.4
astunparse==1.6.3
autograd==1.6.2
autoray==0.6.9
cachetools==5.3.3
certifi==2024.2.2
charset-normalizer==3.3.2
diastatic-malt==2.15.1
future==1.0.0
gast==0.5.4
idna==3.6
jax==0.4.23
jaxlib==0.4.23
ml-dtypes==0.3.2
networkx==3.2.1
numpy==1.26.4
opt-einsum==3.3.0
PennyLane==0.35.0
PennyLane-Catalyst==0.5.0
PennyLane_Lightning==0.35.0
requests==2.31.0
rustworkx==0.14.1
scipy==1.12.0
semantic-version==2.10.0
six==1.16.0
termcolor==2.4.0
toml==0.10.2
tomlkit==0.12.4
typing_extensions==4.10.0
urllib3==2.2.1

Hey @Vincent_Michaud-Riou, thanks for the help. My scipy library was not up to date which caused the problem. Now my code is running without any issues.

I have a follow-up question. If I have access to Nvidia cuquantum and GPU, how do I integrate that with pennylane-catalyst? If possible, please guide me through it. And I think lightning.gpu does not work with Catalyst, right?

Also, as with JAX-JIT, can I accelerate Catalyst using a GPU?

Nice to hear that, @G_Akash. NVIDIA’s cuQuantum is available through Lightning-GPU, which is not integrated with Catalyst yet. But as we’ve shown in a recent blog post, we have Lightning-Kokkos, which can hold its own against cuQuantum on NVIDIA cards. Catalyst’s GPU support is experimental at the moment, so you may be able to use NVIDIA cards via Lightning-Kokkos if you build Catalyst from source. If you are prepared for a bumpy compilation ride, I could give you some more pointers. Otherwise, you may want to wait, as this is something we plan to work on next quarter.

I will look into it. I tried using Catalyst for my problem, but I ran into an issue with the overall runtime. To give you some context, jax-jitting takes about 2.5 minutes (including a compilation time of 75 s) to train the model on a specific dataset, but qjit was taking over 18 minutes, so I had to kill it. I don’t understand what the issue is. I can share my code and explain what my dataset is, if required.

If you don’t mind, do share your code and we’ll see whether jax.jit is simply hard to beat or the discrepancy comes from a misuse of qjit. One thing I’ll mention is that, unlike some other jitters, qjit does not have an LRU cache mechanism. So one must make sure to pass jitted_circuit around to keep the function object within scope, otherwise the compilation will be triggered over and over.
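
To illustrate the point about keeping the jitted object in scope, here is a minimal sketch using a stand-in wrapper that counts compilations. CountingJit is hypothetical and is not Catalyst's qjit; it only mimics the assumption that every new wrapper object compiles on its first call and that nothing is cached globally.

```python
class CountingJit:
    """Hypothetical stand-in for a JIT wrapper without a global cache."""

    compilations = 0  # total number of simulated compilations

    def __init__(self, fn):
        self.fn = fn
        self.compiled = None

    def __call__(self, *args):
        if self.compiled is None:
            CountingJit.compilations += 1  # simulate the expensive compile step
            self.compiled = self.fn
        return self.compiled(*args)


def circuit(theta):
    return 2 * theta  # placeholder for a real circuit


# Anti-pattern: re-wrapping inside the loop compiles on every iteration.
for _ in range(5):
    CountingJit(circuit)(0.7)
rewrapped = CountingJit.compilations

# Preferred: wrap once and reuse the same object.
CountingJit.compilations = 0
jitted_circuit = CountingJit(circuit)
for _ in range(5):
    jitted_circuit(0.7)
wrapped_once = CountingJit.compilations

print(rewrapped, wrapped_once)
```

The same discipline applies to a real jitted function: create it once, outside the training loop, and pass that object to whatever needs to call it.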

These are my initial parameters/conditions,

n_qubits = int(np.log2(X_21.shape[1]))
n_layer = 30
opt = optax.adam(learning_rate=0.05)
params = params_init(n_layer,n_qubits,Y_21)
opt_state = opt.init(params)
x_train,x_val,x_test,y_train,y_val,y_test = data_split(X_data_nf,Y,0.2)

So, these are my common functions,

def circ(weights,inputs):
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits),normalize=True,pad_with=0.5)
    StronglyEntanglingLayers(weights, wires=range(n_qubits),imprimitive=qml.ops.CNOT)
    out =  [qml.expval(qml.PauliZ(i)+qml.PauliZ(i+1)) for i in range(n_qubits-1)[0::2]]
    return out

def qnn(params,inputs):
    weights = params["weights"]
    bias = params["bias"]
    circ_out = circ(weights,inputs)
    out = []
    for i in range(len(circ_out)):
        out.append(circ_out[i] + bias[i])
    return out 

def mse(observed,predictions):
    loss = jnp.sum((observed - predictions) ** 2 / len(observed))
    return jnp.mean(loss)

def predict(params,features):
    preds = qnn_qjit(params,features)
    preds_np = jnp.asarray(preds).T
    return preds_np

batched_predict = jax.vmap(predict, in_axes=(None, 0))

def cost(params, features, observed):
    preds = batched_predict(params,features)
    cost = mse(observed, preds)
    return cost

@jax.jit
def jit_cost(params, features, observed):
    preds = batched_predict(params,features)
    cost = mse(observed, preds)
    return cost

@jax.jit
def update_step(params, opt_state, features, observed):
    train_cost, grads = jax.value_and_grad(jit_cost)(params, features, observed)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, train_cost

def fit(params,opt_state,x_train,y_train,x_val,y_val,epoch,batch_size):    
    train_loss = []
    val_loss = []
    num_train = len(x_train)
    num_val =len(x_val)
    key_t = random.PRNGKey(np.random.randint(0,1e4))
    for i in range(epoch):    
        key_t, key_v = random.split(key_t)
        idx_train = random.choice(key_t,num_train,shape=(batch_size,))
        idx_val = random.choice(key_v,num_val,shape=(batch_size,))
        x_train_batch = jnp.asarray(x_train[idx_train])
        y_train_batch = jnp.asarray(y_train[idx_train])
        x_val_batch = jnp.asarray(x_val)
        y_val_batch = jnp.asarray(y_val)
        start_time = time.time()
        params, opt_state, train_cost = update_step(params, opt_state, x_train_batch, y_train_batch)
        end_time = time.time()
        val_cost = jit_cost(params,x_val_batch,y_val_batch)
        epoch_time = end_time - start_time
        print("Epoch: {:5d} | Loss: {:0.7f} | Val_Loss: {:0.7f} | Time: {:0.4f} seconds".format(i+1, train_cost,val_cost,epoch_time))

        train_loss.append(train_cost)
        val_loss.append(val_cost)
    return params, train_loss,val_loss  

I run the below to train the circuit,

params,train_loss,val_loss = fit(params,opt_state,x_train,y_train,x_val,y_val,300,512)

For only JAX-JIT, I use,

dev = qml.device("default.qubit", wires=n_qubits)
@qml.qnode(dev,interface="jax",diff_method='backprop')  #I use this along with circ function

For catalyst + JAX, I use, (and yes I use JAX-JIT on the cost function and update step function)

dev = qml.device("lightning.qubit", wires=n_qubits)
@qml.qnode(dev,diff_method='adjoint')

For your reference, this is what my input looks like. The corresponding output is the set of parameters defining each curve. There are 10000 curves in total, each sampled at 1024 points, giving an array of shape (10000,1024), which I then reduce to (10000,128) with an autoencoder. My goal is to estimate the parameters of these curves.
[Image: 21cm_signal]

Hi @G_Akash thanks for trying Catalyst!

I see you shared your code when using @jax.jit. Would it be correct to say that the same function decorated with @jax.jit would also be JITed using Catalyst?

In my experience looking at compile-time issues in Catalyst, one of the main causes is not using Catalyst's control flow operations. I would recommend either using the control flow operators manually or, alternatively, using the AutoGraph feature.

See here for some details: AutoGraph guide — Catalyst 0.5.0 documentation
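
For intuition on why plain Python loops can inflate compile times, here is a toy illustration. ToyTracer is a hypothetical recorder, not Catalyst's tracer: it only shows that a Python loop is seen as one operation per iteration, while a structured loop traces its body once.

```python
class ToyTracer:
    """Hypothetical recorder of operations, standing in for a tracing compiler."""

    def __init__(self):
        self.ops = []

    def gate(self, name, wire):
        self.ops.append((name, wire))

    def for_loop(self, n_iter, body):
        # Trace the body once with a symbolic index and
        # record a single structured loop operation.
        inner = ToyTracer()
        body(inner, "i")
        self.ops.append(("for", n_iter, inner.ops))


def layer(tracer, i):
    tracer.gate("RX", i)


# Plain Python loop: the tracer records 1000 separate gate applications.
unrolled = ToyTracer()
for i in range(1000):
    layer(unrolled, i)

# Structured loop: the body is traced once, whatever the iteration count.
structured = ToyTracer()
structured.for_loop(1000, layer)

print(len(unrolled.ops), len(structured.ops))
```

In Catalyst, the structured counterparts are the control flow operators such as for_loop, or qjit(autograph=True), which converts supported Python control flow automatically. The AutoGraph guide covers the details and limitations.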

I would be interested in seeing exactly the program that took more than 18 minutes to compile / run exactly as coded. That way we can make sure that this won’t happen again in the future. Thanks!