I followed this tutorial https://www.youtube.com/watch?v=U0s0f995w14 to build a minimal version of the transformer architecture, but I am confused about the final shape of the output.
On the final lines (slightly modified):
print(x.shape)
print(trg[:, :-1].shape)
out = model(x, trg[:, :-1])
print(out.shape)
The output shapes don't seem to make sense:
torch.Size([2, 9])      # input sentence: (num_examples, num_tokens)
torch.Size([2, 7])      # generated sentence so far: (num_examples, num_tokens_generated)
torch.Size([2, 7, 10])  # probabilities for the next token: (num_examples, ???, size_vocab)
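For comparison, the same shape pattern seems to fall out of PyTorch's built-in nn.Transformer (a minimal sketch with random embeddings and batch_first=True, not the tutorial's model):

import torch
import torch.nn as nn

# Not the tutorial's model -- just the built-in nn.Transformer,
# to check whether the (2, 7, 10) pattern is standard behavior.
d_model, size_vocab = 32, 10
transformer = nn.Transformer(d_model=d_model, nhead=4, batch_first=True)
to_vocab = nn.Linear(d_model, size_vocab)

src = torch.randn(2, 9, d_model)  # stands in for the embedded input sentence
tgt = torch.randn(2, 7, d_model)  # stands in for the embedded trg[:, :-1]

out = to_vocab(transformer(src, tgt))
print(out.shape)  # torch.Size([2, 7, 10]) -- the same 7 shows up here too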
The transformer is supposed to predict the next token across 2 training examples (which is why there's a 2 for the number of examples and a 10 for the size of the vocab) by generating probabilities for each token in the vocab. But I can't make sense of why there's a 7 there. The only explanation I can come up with is that it outputs all the predictions simultaneously, but I thought that would require feeding the outputs back through the transformer iteratively, and that never happens (see lines 267-270). A toy sketch of that hypothesis (made-up tensors, shapes only) is below.
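import torch

# Made-up values, shapes only -- just trying to make my hypothesis concrete.
num_examples, num_tokens_generated, size_vocab = 2, 7, 10

# Pretend this is what the model returned: one distribution over the
# vocab for EVERY position of trg[:, :-1], not just for the last one.
out = torch.randn(num_examples, num_tokens_generated, size_vocab)

# If the 7 means "one next-token prediction per target position", then
# the prediction for the actual next token would just be the last slice:
next_token_logits = out[:, -1, :]          # torch.Size([2, 10])
next_token = next_token_logits.argmax(-1)  # torch.Size([2])
print(next_token.shape)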
So is there a mistake, or am I misunderstanding something? What is that output shape supposed to represent?