Hi @rc17782,
I don’t have enough information to run your code locally, so I’m not sure how the optimization is being performed.
In case it helps, I wrote this code example for someone else: it splits each batch into individual inputs inside the forward pass, but still performs the optimization step over whole batches.
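The general pattern can be sketched as follows. This is not the original example. It is a minimal, framework-free illustration that assumes a toy linear model `w * x + b` with hand-written mean-squared-error gradients, and the function names `forward_per_sample`, `train_step` and `train` are hypothetical. The forward pass handles one input at a time, while the loss, gradients and parameter update are computed once per batch.

```python
import random


def forward_per_sample(w, b, batch):
    """Run the (toy) model on each input individually and collect the outputs."""
    outputs = []
    for x in batch:  # one sample at a time inside the forward pass
        outputs.append(w * x + b)
    return outputs


def train_step(w, b, batch_x, batch_y, lr=0.1):
    """One optimization step computed over the whole batch."""
    preds = forward_per_sample(w, b, batch_x)
    n = len(batch_x)
    # Mean squared error averaged over the batch
    loss = sum((p - y) ** 2 for p, y in zip(preds, batch_y)) / n
    # Gradients of the batch-mean loss with respect to w and b
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, batch_y, batch_x)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, batch_y)) / n
    # A single parameter update per batch, not per sample
    return w - lr * grad_w, b - lr * grad_b, loss


def train(data, epochs=500, batch_size=4, lr=0.1, seed=0):
    """Mini-batch training loop over (x, y) pairs."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    data = list(data)
    for _ in range(epochs):
        rng.shuffle(data)
        for i in range(0, len(data), batch_size):
            chunk = data[i:i + batch_size]
            xs = [x for x, _ in chunk]
            ys = [y for _, y in chunk]
            w, b, _ = train_step(w, b, xs, ys, lr)
    return w, b


if __name__ == "__main__":
    # Toy, noise-free data generated from y = 2x + 1
    data = [(x / 10, 2 * (x / 10) + 1) for x in range(10)]
    w, b = train(data)
    print(f"w={w:.3f}, b={b:.3f}")
```

In a framework such as PyTorch, the per-sample loop would typically live inside the model's `forward` method, while the loss backward pass and the optimizer step are still called once per batch.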
If this doesn’t answer your question, would you be able to share a minimal reproducible example of your code?
I hope this helps!