Loading a state_dict from the quantum_weights.pt file

Hi Andrea,
Could you help me understand the following error when loading the pre-trained quantum weights? I have been trying to adapt the Quantum transfer learning code from your paper to my problem statement.
As you explained, I trained my model_hybrid on the PennyLane simulator and saved the trained model as a .pt file. When I try to load the saved state_dict in order to run the model on actual quantum hardware, I get a "keys mismatch" error as follows:


P.S.: I am following your exact code for my work. The ants vs. bees example ran perfectly on the IBMQ machine, as stated in my previous reply to "Quantum transfer learning code (Mari et al., 2019) - IBMQDevice endless execution", and I had no issues loading the quantum_weights.pt file there. My problem is with a different image classification task. How did you save the pre-trained model? I used the following function to do so:
torch.save(model.state_dict(), PATH)

Did you pass any specific arguments when saving the trained model?
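For context, the keys stored in a state_dict are named relative to the module whose state_dict() was called, so what matters is which module you save, not extra arguments to torch.save(). The minimal sketch below illustrates this with a hypothetical toy model (ToyHybrid, with a plain nn.Linear named fc standing in for the quantum layer and a hypothetical backbone layer); it is not the actual model_hybrid from the paper's code.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for model_hybrid: a "backbone" layer plus a final
# layer named `fc`. In the real code `fc` would be the dressed quantum circuit.
class ToyHybrid(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))

model = ToyHybrid()

# Keys of the full model carry the submodule prefix ...
full_keys = list(model.state_dict().keys())
print(full_keys)  # ['backbone.weight', 'backbone.bias', 'fc.weight', 'fc.bias']

# ... while the keys of the last layer alone do not.
fc_keys = list(model.fc.state_dict().keys())
print(fc_keys)  # ['weight', 'bias']
```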


Hi @angelinaG,
I think the problem is that you saved the parameters of the full model (model_hybrid), but you are trying to load them into only the last layer of the model (model_hybrid.fc). You have two alternative options:

  1. replace model_hybrid.fc.load_state_dict() with model_hybrid.load_state_dict(),
  2. keep model_hybrid.fc.load_state_dict() as it is, but repeat the training phase and save model_hybrid.fc.state_dict() instead of model_hybrid.state_dict().
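The two options can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions: ToyHybrid, its backbone layer, and the nn.Linear used for fc are hypothetical stand-ins for model_hybrid and its quantum layer, and an in-memory io.BytesIO buffer replaces the PATH file so the example runs without touching the disk.

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in for model_hybrid; `fc` plays the role of the
# trainable quantum layer from the paper's code.
class ToyHybrid(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))

trained = ToyHybrid()

# What was done originally: save the FULL model's parameters.
buf = io.BytesIO()  # stands in for PATH
torch.save(trained.state_dict(), buf)
buf.seek(0)
full_state = torch.load(buf)

# Loading those keys into the last layer alone fails (strict loading):
# 'weight'/'bias' are missing and 'backbone.*'/'fc.*' are unexpected.
fresh = ToyHybrid()
try:
    fresh.fc.load_state_dict(full_state)
    mismatch = False
except RuntimeError:
    mismatch = True

# Option 1: load the full state_dict into the full model.
option1 = ToyHybrid()
option1.load_state_dict(full_state)

# Option 2: save only the last layer, then load it into the last layer.
buf_fc = io.BytesIO()
torch.save(trained.fc.state_dict(), buf_fc)
buf_fc.seek(0)
option2 = ToyHybrid()
option2.fc.load_state_dict(torch.load(buf_fc))
```

Option 2 only restores fc, so it assumes the rest of the network is rebuilt identically when the model is recreated (for example, a frozen pre-trained feature extractor, as in the transfer-learning setting); option 1 restores every parameter.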

This should solve your problem.
Andrea


Thank you very much for your prompt response, @andreamari. I really appreciate it.
I shall try this out.

Hi @andreamari!
Thank you once again!
I tried option 2 and it worked :)