Hello! If applicable, put your complete code example down below. Make sure that your code:
- is 100% self-contained: someone can copy-paste exactly what is here and run it to reproduce the behaviour you are observing
- includes comments
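import torch

def info_nce_loss_torch(features, n_views, batch_size):
    features = torch.from_numpy(features)
    # Sample index of every row: [0, 1, ..., batch_size-1] repeated n_views times
    labels = torch.cat([torch.arange(batch_size) for i in range(n_views)], dim=0)
    # labels[i, j] = 1.0 when rows i and j are views of the same sample
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    # labels = labels.to(self.args.device)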
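For reference, the `labels` matrix built in the function marks which rows of `features` are views of the same sample. The snippet below is an illustrative NumPy re-implementation (not part of the original code; `np.tile` stands in for `torch.cat`) for `batch_size = 2` and `n_views = 2`:

```python
import numpy as np

batch_size, n_views = 2, 2
# Same row ordering as torch.cat([torch.arange(batch_size) for i in range(n_views)])
labels = np.tile(np.arange(batch_size), n_views)  # [0 1 0 1]
# Entry (i, j) is 1.0 when rows i and j come from the same sample
mask = (labels[None, :] == labels[:, None]).astype(float)
print(mask)
# [[1. 0. 1. 0.]
#  [0. 1. 0. 1.]
#  [1. 0. 1. 0.]
#  [0. 1. 0. 1.]]
```

Each row contains `n_views` ones: the diagonal entry (the row matched with itself) plus its other views, which is why the self-similarity on the diagonal is typically excluded before the mask is used to pick out positives.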
If you want help with diagnosing an error, please put the full error message below:
Cell In[3], line 52, in info_nce_loss_torch(features, n_views, batch_size)
     51 def info_nce_loss_torch(features, n_views, batch_size):
---> 52     features = torch.from_numpy(features)
     53     labels = torch.cat([torch.arange(batch_size) for i in range(n_views)], dim=0)
     54     labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()

TypeError: expected np.ndarray (got ArrayBox)
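The `ArrayBox` in this error is autograd's tracing wrapper: when a function is differentiated through PennyLane's autograd interface, its NumPy-style inputs arrive as `ArrayBox` objects rather than plain `np.ndarray`, and `torch.from_numpy` only accepts real ndarrays. Handing the data to torch would also break autograd's gradient tracking. One possible workaround, sketched below under the assumption that the loss should stay differentiable with respect to `features`, is to write the loss with NumPy operations only; swapping `import numpy as np` for `from pennylane import numpy as np` should then let autograd trace it. The function name `info_nce_loss_np`, the `temperature` default of 0.07 and the choice to average over the `n_views - 1` positives of each row are assumptions for illustration (a SimCLR-style NT-Xent formulation), not the original code.

```python
import numpy as np  # assumption: use `from pennylane import numpy as np` under autograd


def info_nce_loss_np(features, n_views, batch_size, temperature=0.07):
    """Hypothetical NumPy-only InfoNCE / NT-Xent loss (requires n_views >= 2).

    features: shape (n_views * batch_size, dim), rows ordered view-major,
    i.e. all batch_size samples of view 0, then view 1, and so on.
    """
    n = n_views * batch_size
    # Sample index of every row, matching the torch.cat construction
    ids = np.tile(np.arange(batch_size), n_views)
    same_sample = (ids[:, None] == ids[None, :]).astype(float)
    self_mask = np.eye(n)
    positives = same_sample - self_mask  # other views of the same sample only

    # L2-normalise rows so dot products are cosine similarities
    norms = np.sqrt(np.sum(features ** 2, axis=1, keepdims=True))
    z = features / norms
    logits = np.dot(z, z.T) / temperature
    # Push self-similarity far down so it drops out of the softmax denominator
    logits = logits - 1e9 * self_mask

    # Row-wise log-softmax, shifted by the row maximum for numerical stability
    row_max = np.max(logits, axis=1, keepdims=True)
    log_denom = row_max + np.log(np.sum(np.exp(logits - row_max), axis=1, keepdims=True))
    log_prob = logits - log_denom

    # Average -log p over each row's positives, then over all rows
    per_row = -np.sum(log_prob * positives, axis=1) / np.sum(positives, axis=1)
    return np.mean(per_row)


# Two identical views per sample, orthogonal samples: loss is close to 0
f = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
print(info_nce_loss_np(f, n_views=2, batch_size=2))
```

The arithmetic masks (multiplying by `positives`, subtracting a large constant on the diagonal) are used instead of boolean indexing or in-place assignment, since those are the kinds of operations autograd handles least gracefully. As a sanity check, if every row of `features` is identical and `n_views = 2`, each row sees `2 * batch_size - 1` equally likely candidates, so the loss equals `log(2 * batch_size - 1)`.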
And, finally, make sure to include the versions of your packages. Specifically, show us the output of qml.about().