Hello,
I'm having real difficulty breaking this apart; a lot is packed into it.
Could someone help me turn this into a script that runs inference on a single image, rather than drawing this subplot grid?
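import matplotlib.pyplot as plt
import torch

# dataloaders, device, class_names and imshow are assumed to be defined
# elsewhere in the script (this mirrors the PyTorch transfer-learning tutorial).


def visualize_model(model, num_images=6, fig_name="Predictions"):
    images_so_far = 0
    _fig = plt.figure(fig_name)
    model.eval()  # put dropout/batch-norm layers into inference mode
    with torch.no_grad():  # gradients are not needed for prediction
        for inputs, labels in dataloaders["validation"]:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            # Index of the highest score in each row = predicted class
            _, preds = torch.max(outputs, 1)
            # Draw each image of the batch into its own cell of a (num_images // 2) x 2 grid
            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis("off")
                ax.set_title("[{}]".format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])
                if images_so_far == num_images:
                    return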
Just simple inference on a single image, or a loop over a folder?
Thanks!