Hi @quantumlover
Welcome to the forum!
I believe the JAX version you are using is not yet compatible with PennyLane.
I ran the code myself using JAX v0.4.28 and it works. This was also discussed in a previous post.
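If it helps, here is a small sketch for checking which versions you have installed and whether your JAX is newer than the one that worked for me. The helper names (`installed_version`, `version_tuple`) and the `KNOWN_GOOD` constant are just made up for this example, and the comparison only looks at the first three numeric parts of the version string.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# The JAX version that worked for me in this thread (not an official support bound).
KNOWN_GOOD = (0, 4, 28)


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def version_tuple(v: str) -> Tuple[int, ...]:
    """Turn '0.4.28' (or '0.4.30.dev1') into (0, 4, 28), dropping non-numeric suffixes."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# Print what is installed in the current environment.
for pkg in ("jax", "jaxlib", "pennylane"):
    print(pkg, installed_version(pkg))

jax_version = installed_version("jax")
if jax_version is None:
    print("JAX is not installed in this environment.")
elif version_tuple(jax_version) > KNOWN_GOOD:
    # Assumed downgrade command; adjust it to your own setup (CPU/GPU, conda, etc.):
    #   python -m pip install "jax==0.4.28" "jaxlib==0.4.28"
    print("JAX is newer than 0.4.28; downgrading may help.")
else:
    print("JAX is at or below 0.4.28.")
```

Run it in the same environment (or notebook kernel) where you see the error, so that it reports the versions your code is actually using.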
Let me know if downgrading your JAX version works for you.