Hi @kekcikon, welcome to the forum!
Could you provide a minimal working version of your code? It would help me see how you set everything up.
I'm also wondering whether you have seen this demo on using the special unitary with JAX, although I see that you are using Torch. You may also find this and this past question from the forum useful.
For now, I hope this helps. If not, let me know and we’ll keep the conversation going.