Это стандартная линейная проекция. Вы можете просто добавить nn.Linear(2 * model_dim, model_dim)
, где model_dim
- размерность RNN.
Кодер является двунаправленным, с одним RNN в обоих направлениях, имеющим выходной размер model_dim
. Декодер работает только в прямом направлении, поэтому он имеет состояния только model_dim
размеров. На самом деле он сохраняет много параметров в центре внимания, потому что проецирует для ключей и значений только половину размера, потому что они проецируются с model_dim
вместо 2 * model_dim
.