How to make certain parameters "untrainable" with JAX?

Hello, I have a QNode which I’m training using the JAX interface. However, this QNode also takes some parameters that are just fixed parameters of my model and shouldn’t be trained. How do I specify this when optimizing my QNode?

Thanks!

Hey @NickGut0711,

Make sure that you’re using JAXopt or Optax to optimize your JAX-interfaced circuit (see the JAX interface page of the PennyLane 0.34.0 documentation for more details). You can then choose which arguments are differentiated, and therefore trained, with the argnums keyword argument of jax.grad. For example,

```python
res = jax.grad(circuit, argnums=0)(phi, theta)
```

This would only differentiate with respect to phi (the 0th argument); theta is treated as a constant. You can then use Optax to update the parameters, as in this example taken from here.

```python
import pennylane as qml
from jax import numpy as jnp
import jax
import optax

learning_rate = 0.15

dev = qml.device("default.qubit", wires=1, shots=None)

# JIT-compile the JAX-interfaced QNode
@jax.jit
@qml.qnode(dev, interface="jax")
def energy(a):
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

optimizer = optax.adam(learning_rate)

# Only the parameters passed to optimizer.init are tracked and updated
params = jnp.array(0.5)
opt_state = optimizer.init(params)

for _ in range(200):
    # Gradient of the energy with respect to its only argument
    grads = jax.grad(energy)(params)
    # Adam converts the gradients into parameter updates
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```

Let me know if that helps!