Как я должен понимать аргументы nn.Embeddings num_embeddings и embedding_dim? - PullRequest
1 голос
/ 09 ноября 2019

Я пытаюсь привыкнуть к классу Embedding в модуле PyTorch nn.

Я заметил, что довольно много других людей сталкивались с такой же проблемой, как и я, и поэтому разместили вопросына дискуссионном форуме PyTorch и по переполнению стека, но у меня все еще есть некоторая путаница.

Согласно официальной документации , передаются следующие аргументы: num_embeddings и embedding_dimкаждый из которых относится к тому, насколько велик наш словарь (или словарь) и сколько измерений мы хотим, чтобы наши вложения были соответственно.

Что меня смущает, так это то, как именно я должен их интерпретировать. Например, небольшой практический код, который я запустил:

import torch
import torch.nn as nn


embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # (2, 4)

b = torch.LongTensor([[1, 2, 3], [2, 3, 1], [4, 5, 6], [3, 3, 3], [2, 1, 2],
                      [6, 7, 8], [2, 5, 2], [3, 5, 8], [2, 3, 6], [8, 9, 6],
                      [2, 6, 3], [6, 5, 4], [2, 6, 5]]) # (13, 3)

c = torch.LongTensor([[1, 2, 3, 2, 1, 2, 3, 3, 3, 3, 3],
                      [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]) # (2, 11)

Когда я запускаю a, b и c через переменную embedding, я получаю встроенные результаты фигур (2, 4, 3), (13, 3, 3), (2, 11, 3).

Что меня смущает, так это то, что я подумал, что количество сэмплов у нас превышает заранее определенное количество вложений, мы должны получить ошибку? Поскольку определенное мной embedding имеет 10 вложений, не должно ли b выдавать ошибку, поскольку это тензор, содержащий 13 слов измерения 3?

1 Ответ

1 голос
/ 09 ноября 2019

В вашем случае, вот как интерпретируется ваш входной тензор:

a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # 2 sequences of 4 elements

Более того, так интерпретируется ваш слой встраивания:

embedding = nn.Embedding(num_embeddings=10, embedding_dim=3) # 10 distinct elements and each those is going to be embedded in a 3 dimensional space

Так что это не такНе имеет значения, если ваш входной тензор содержит более 10 элементов, если они находятся в диапазоне [0, 9]. Например, если мы создадим тензор из двух элементов, таких как:

d = torch.LongTensor([[1, 10]]) # 1 sequence of 2 elements

, мы получим следующую ошибку, когда мы пропустим этот тензор через слой внедрения:

RuntimeError:индекс вне диапазона: пытался получить доступ к индексу 10 вне таблицы с 9 строками

Подводя итог num_embeddings - это общее количество уникальных элементов в словаре, а embedding_dim - это размер каждого вложенноговектор один раз прошел через слой вложения. Следовательно, у вас может быть тензор из 10+ элементов, если каждый элемент в тензорном диапазоне находится в диапазоне [0, 9], поскольку вы определили размер словаря в 10 элементов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...