Incompatible function arguments error on lightning.qubit with JAX

If you really need true batching/broadcasting, I suggest using default.qubit; alternatively, you can try mapping the QNode over the inputs with jax.pmap:

import os
# XLA reads XLA_FLAGS when JAX initialises its backends, so this must be set
# before `import jax`. The flag asks XLA to expose the host CPU as 6 separate
# devices, which is what lets jax.pmap shard work across "devices" on one CPU.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=6'

import jax
# Queries the CPU devices; the returned list is discarded here.
# NOTE(review): presumably kept as a sanity check that the 6 virtual devices
# exist (e.g. when run interactively) — confirm it is still wanted.
jax.devices("cpu")
import pennylane as qml
from pennylane import numpy as np

# The `jax.config` submodule (`from jax.config import config`) was deprecated
# and later removed from JAX, so that import fails on current releases. The
# same config object is available as an attribute of the already-imported
# `jax` package; the `jax_config` name is kept bound for backward compatibility.
jax_config = jax.config

# JAX defaults to 32-bit types; enable 64-bit (double-precision) support.
jax_config.update("jax_enable_x64", True)

# Two-wire lightning.qubit device on which the QNode below is executed.
dev = qml.device('lightning.qubit', wires=2)


@jax.jit
@qml.qnode(dev, diff_method="parameter-shift")
def circuit(inputs):
    """Rotate wire 0 by ``inputs`` about X and measure PauliZ on both wires.

    Args:
        inputs: rotation angle passed to the RX gate on wire 0.

    Returns:
        A list with the PauliZ expectation value of wire 0 and of wire 1,
        in that order.
    """
    qml.RX(inputs, wires=0)
    return [
        qml.expval(qml.PauliZ(0)),
        qml.expval(qml.PauliZ(1)),
    ]

# One rotation angle per mapped device. jax.pmap requires the mapped axis to be
# no larger than the number of available devices (6, per the XLA_FLAGS setting
# at the top of the script).
inputs = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

# Map the QNode over the leading axis of its single argument, running one
# circuit evaluation per device.
batched_qnode = jax.pmap(circuit, in_axes=0)
expectation_values = batched_qnode(inputs)
print(expectation_values)

That said, the overall utility of this approach is debatable.