Hi!
As I mentioned in my previous post, I’m trying to write a quantum circuit that learns MNIST classifications. Because the raw images are too large for a quantum circuit, each image is first run through an autoencoder that reduces the dimensionality from 28 * 28 pixels to a single vector of length 10. I then feed that encoding into a circuit with 10 wires and use the expectation value of each wire as the score for the corresponding class. I’ve got it all working, but it’s pretty slow.
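For context, the encoder half of the autoencoder is just an ordinary PyTorch network along these lines (an illustrative sketch, not my exact architecture; the hidden layer size here is a placeholder):

```python
import torch.nn as nn

ENCODING_SIZE = 10

# illustrative encoder: 28 * 28 = 784 pixels down to ENCODING_SIZE features
encoder = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, ENCODING_SIZE),
)
```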
This is what the circuit, cost, and grad code look like. I’ve omitted the setup and imports for brevity, but I can post them if they might affect things.
```python
# x will be a length-ENCODING_SIZE vector
# that represents the encoding of an MNIST image
# thetas is of size 2 * NUM_QUBITS
@qml.qnode(dev)
def circuit(x, thetas):
    # encode the input features as RX rotation angles
    for i in range(ENCODING_SIZE):
        RX(x[i], wires=i)
    # entangle neighbouring wires
    for i in range(NUM_QUBITS - 1):
        CNOT(wires=[i, i + 1])
    # trainable rotations
    for i in range(NUM_QUBITS):
        RX(thetas[i], wires=i)
    for i in range(NUM_QUBITS, 2 * NUM_QUBITS):
        RY(thetas[i], wires=(i - NUM_QUBITS))
    # one PauliZ expectation per wire, used as the score for that class
    return tuple(qml.expval.PauliZ(wires=i) for i in range(NUM_QUBITS))
```
```python
# X is of size (b, 10), actual_labels is of size (b,)
# thetas is of size 2 * NUM_QUBITS.
# implements the cross-entropy classification loss
# described here:
# https://pytorch.org/docs/stable/nn.html#crossentropyloss
# with numerical stability
def cost(X, actual_labels, thetas):
    b = X.shape[0]
    yhats = []
    for i in range(b):
        yhat = circuit(X[i], thetas)
        yhats.append(yhat)
    st = np.stack(yhats)
    # subtract the row-wise max for numerical stability
    # (the shift is applied to both terms, so the loss value is unchanged)
    shifted = st - np.max(st, axis=1)[:, np.newaxis]
    actual_class_vals = shifted[range(b), actual_labels]
    the_sum = np.log(np.sum(np.exp(shifted), axis=1))
    return np.mean(-actual_class_vals + the_sum)
```
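In other words, with $\hat{y}$ the vector of 10 expectation values for one example and $c$ its true label, `cost` is meant to compute the cross-entropy from the PyTorch docs linked above, with the row maximum subtracted from both terms for numerical stability:

$$
\ell(\hat{y}, c) = -\hat{y}_c + \log \sum_{j=1}^{10} e^{\hat{y}_j}
= -(\hat{y}_c - m) + \log \sum_{j=1}^{10} e^{\hat{y}_j - m},
\qquad m = \max_j \hat{y}_j,
$$

and the returned value is the mean of $\ell$ over the batch.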
```python
# the data is loaded in batches of size 4, so
# X is of size (4, 10)
X = encoder(inputs.view(len(labels), -1))

start = time.time()
qml.grad(cost, argnum=2)(X.numpy(), labels.numpy(), thetas)
print(time.time() - start)
```
This operation takes about 200 seconds (and scales linearly with the batch size, so roughly 50 seconds per example). At that rate, a single pass over the full 60,000-image dataset would take about a month. Is there anything I can do to speed this up, or is this just the nature of the implementation, with not much to be done about it? The reason I ask is that this is for a class project (CS269Q at Stanford) and we only have about two weeks remaining.
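For reference, the back-of-the-envelope math behind the "about a month" figure:

```python
# rough scaling estimate: ~200 s per batch of 4, i.e. ~50 s per example,
# extrapolated to one pass over the 60000-image training set
seconds_per_example = 200 / 4
total_seconds = seconds_per_example * 60000
print(total_seconds / 86400)  # ~34.7 days
```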
I have two thoughts so far on why it is slow:
- The cost function is somewhat complicated, so computing its gradient is a hassle. However, I feel like any classification task is going to look like this. Should I switch to a dataset on which I can perform regression instead?
- There are too many wires. I could try to reduce the number of wires, but I picked 10 so that each wire corresponds to one of the 10 classes. If I cut down to, say, 5 wires, how would I classify after that? I suppose I could follow the circuit with a simple matrix multiplication that maps the 5 wire outputs to the 10 class scores and learn that 5 x 10 matrix as well (see the sketch after this list). The only problem is that this really increases the number of parameters to learn, which may or may not be a problem; I’m not sure.
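To make the second idea concrete, this is roughly the post-processing step I have in mind (purely hypothetical; `W` doesn’t exist yet, and `circuit` would have to be rewritten to return 5 expectation values instead of 10):

```python
NUM_QUBITS = 5
NUM_CLASSES = 10

# hypothetical: W is a learned (NUM_QUBITS, NUM_CLASSES) matrix trained alongside thetas
def class_scores(x, thetas, W):
    wire_expvals = circuit(x, thetas)  # length-5 vector of PauliZ expectations
    return np.dot(wire_expvals, W)     # length-10 vector of class scores for the cost above
```

That would add 5 x 10 = 50 extra parameters on top of the 2 * NUM_QUBITS circuit parameters.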
Any thoughts on this would be very much appreciated. Thanks so much!