Hello.
I am having some problems creating a keras model with a quantum layer. I obtain the following Warning when I try to train my model:
WARNING:tensorflow:You are casting an input of type complex128 to an incompatible dtype float32. This will discard the imaginary part and may not be what you intended.
My code is the following one:
import pennylane as qml
import sklearn.datasets
from qiskit_machine_learning import datasets
import qiskit_machine_learning as qiskitml
from sklearn.model_selection import train_test_split
n_qubits = 5
dev = qml.device("default.qubit", wires=n_qubits)
@qml.qnode(dev)
def qnode(inputs, weights):
qml.templates.AngleEmbedding(np.pi*inputs, wires=range(n_qubits))
qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
shape = qml.StronglyEntanglingLayers.shape(n_layers=1, n_wires=5)
weight_shapes = {"weights": shape}
inputs = np.random.rand(n_qubits).astype(comp_dtype)
weights = np.random.rand(2, n_qubits, 3).astype(comp_dtype)
X_train, X_test, Y_train, Y_test =train_test_split(x, y,test_size=0.15, random_state=0)
X = tf.constant(X_train,dtype=comp_dtype)
Y = tf.constant(Y_train,dtype=comp_dtype)
X_test= tf.constant(X_test,dtype=comp_dtype)
Y_test= tf.constant(Y_test,dtype=comp_dtype)
print(qnode(X_train[0],weights))
q_layer = qml.qnn.KerasLayer(qnode, weight_shapes, output_dim=n_qubits, dtype=comp_dtype)
q_layer.build(2)
q_model = tf.keras.models.Sequential()
q_model.add(tf.keras.layers.Dense(n_qubits, activation='sigmoid', input_dim=5))
q_model.add(q_layer)
q_model.add(tf.keras.layers.Dense(1, activation='softmax'))
q_model.summary()
opt = tf.keras.optimizers.Adam(learning_rate=0.05)
q_model.compile(loss='huber', optimizer=opt, metrics=["accuracy"])
q_model.fit(X, Y, epochs=8, batch_size=5, verbose=1, validation_data=(X_test, Y_test))
And my versions of tensorflow and pennylane are the following ones:
tensorflow == Version: 2.12.0
Pennylane == Version: 0.29.1
The problem is that this warning appears many times, so if someone can help me solving it or just ignoring it would be perfect.
Thanks in advance.