Issue using lightning device with JAX

Hi!
I was able to run your code without errors by adding the @qml.qjit decorator on top of @qml.qnode. Note that qml.qjit requires the Catalyst package (pennylane-catalyst) to be installed.

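import jax
import pennylane as qml

def make_circuit(dev, n_qubits):
    @qml.qjit  # compile the QNode with Catalyst
    @qml.qnode(dev)
    def circuit(x):
        # One RX rotation per wire, parameterized by the matching entry of x
        for i in range(n_qubits):
            qml.RX(x[i], wires=i)
        return qml.expval(qml.PauliZ(wires=0))
    # Batch over the leading axis of the input, shape (batch, n_qubits)
    return jax.vmap(circuit, in_axes=0)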
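
If you want to sanity-check the batched output, here is a minimal sketch of a pure-JAX reference. It relies on the fact that RX(theta) applied to |0> gives an expectation value of cos(theta) for PauliZ on that wire, so the circuit above should return cos(x[0]) for each row of the batch. The names expected_z0, batched_reference, and xs, and the choice of n_qubits = 3, are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp

def expected_z0(x):
    # For RX(x[0]) applied to |0>, <Z> on wire 0 equals cos(x[0]).
    # Rotations on the other wires do not affect this expectation value.
    return jnp.cos(x[0])

# Same batching convention as jax.vmap(circuit, in_axes=0):
# map over the leading axis of a (batch, n_qubits) array.
batched_reference = jax.vmap(expected_z0, in_axes=0)

n_qubits = 3  # hypothetical size, for illustration only
xs = jnp.linspace(0.0, 1.0, 4 * n_qubits).reshape(4, n_qubits)
reference = batched_reference(xs)  # one value per batch row, shape (4,)
```

Comparing reference against make_circuit(dev, n_qubits)(xs) with something like jnp.allclose should confirm that the vmapped circuit is batching over the axis you expect.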

You can see this post where a similar issue was discussed.