загрузка предварительно обученной модели в pytorch - PullRequest
0 голосов
/ 09 октября 2019

Во-первых, я хотел бы извиниться, этот вопрос может звучать глупо, но я новичок в углубленном изучении. Кто-нибудь может объяснить мне следующие строки кода, которые использовались для загрузки предварительно обученной модели в PyTorch?

# Retrieving model parameters from checkpoint.
vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
embedding_dim = checkpoint["model"]['_word_embedding.weight'].size(1)
hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

Я не могу понять проекцию, вес, классификацию, размер (0), размер(1) в приведенном выше тексте.

1 Ответ

1 голос
/ 10 октября 2019
import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()

        vocab_size = 10000
        embed_size = 100
        # word embedding layer
        self._word_embedding = nn.Embedding(vocab_size, embed_size)
        # linear transformation layers (no bias)
        self._projection = nn.ModuleList([nn.Linear(100, 50, bias=False)
                                          for i in range(2)])
        # linear transformation layers (no bias)
        self._classification = nn.ModuleList([nn.Linear(50, 50, bias=False)
                                              for i in range(4)])

    def forward(self):
        return


model = Model()
checkpoint = {
    'model': model.state_dict()  # OrderedDict
}

# _word_embedding.weight --> torch.Size([10000, 100])
# _projection.0.weight --> torch.Size([50, 100])
# _projection.1.weight --> torch.Size([50, 100])
# _classification.0.weight --> torch.Size([50, 50])
# _classification.1.weight --> torch.Size([50, 50])
# _classification.2.weight --> torch.Size([50, 50])
# _classification.3.weight --> torch.Size([50, 50])

for name, param in checkpoint['model'].items():
    print(name, '-->', param.size()) # see above

# similarly, we can print as follows
print(checkpoint["model"]["_word_embedding.weight"].size(0)) # 10000
print(checkpoint["model"]["_word_embedding.weight"].size(1)) # 100
print(checkpoint["model"]["_projection.0.weight"].size(0)) # 50
print(checkpoint["model"]["_classification.0.weight"].size(0)) # 50

Подготовил пример, который поможет вам понять значение этих четырех строк.

Я не могу понять проекцию, вес, классификацию, размер (0), размер (1)в приведенном выше тексте.

  • проекция : слой нейронной сети
  • классификация : слой нейронной сети
  • вес : весовая матрица соответствующих NN слоев
  • размер (0) : размер первого измерения весовой матрицы
  • size (1) : размер второго измерения весовой матрицы
...