Я строю свою модель в Google Colab.
Я создал собственную матрицу встраивания
import torchtext.vocab as vocab
custom_embeddings = vocab.Vectors(name = 'custom_embeddings.txt')
TEXT.build_vocab(train_data, vectors = custom_embeddings)
Вот код для класса Encoder:
class Encoder(nn.Module):
def __init__(self,
input_dim,
hid_dim,
n_layers,
n_heads,
pf_dim,
dropout,
device,
max_length = 100):
super().__init__()
self.device = device
self.tok_embedding = nn.Embedding(input_dim, hid_dim)
# step added for custom embedding
self.tok_embedding.weight.data.copy_(custom_embeddings)
self.pos_embedding = nn.Embedding(max_length, hid_dim)
self.layers = nn.ModuleList([EncoderLayer(hid_dim,
n_heads,
pf_dim,
dropout,
device)
for _ in range(n_layers)])
self.dropout = nn.Dropout(dropout)
self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
def forward(self, src, src_mask):
#src = [batch size, src len]
#src_mask = [batch size, src len]
batch_size = src.shape[0]
src_len = src.shape[1]
pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
#pos = [batch size, src len]
src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
#src = [batch size, src len, hid dim]
for layer in self.layers:
src = layer(src, src_mask)
#src = [batch size, src len, hid dim]
return src
Теперь, когда я пытаюсь создать объект Encoder, я получаю сообщение об ошибке для пользовательского внедрения, которое я использовал.
enc = Encoder(INPUT_DIM,
HID_DIM,
ENC_LAYERS,
ENC_HEADS,
ENC_PF_DIM,
ENC_DROPOUT,
device)
Ошибка описания:
TypeError Traceback (most recent call last)
<ipython-input-72-06d3631c029b> in <module>()
18 ENC_PF_DIM,
19 ENC_DROPOUT,
---> 20 device)
21
22 dec = Decoder(OUTPUT_DIM,
<ipython-input-59-6c2f23451d01> in __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length)
16
17 # step added for custom embedding
---> 18 self.tok_embedding.weight.data.copy_(custom_embeddings)
19
20 self.pos_embedding = nn.Embedding(max_length, hid_dim)
TypeError: copy_(): argument 'other' (position 1) must be Tensor, not Vectors
Не могли бы вы помочь мне исправить эту ошибку?
Заранее спасибо!