что такое определение класса nn.Linear в pytorch?
Из документации :
CLASS torch.nn.Linear(in_features, out_features, bias=True)
Применяет линейное преобразование к входящим данным: y = xW ^ T + b
Параметры:
- in_features - размер каждой входной выборки
- out_features - размер каждой выходной выборки
- смещение - Если установлено значение False, слой не будет изучатьаддитивное смещение.По умолчанию: True
Обратите внимание, что веса W (формы (out_features, in_features)) и смещения b (формы (out_features)) в линейном уравнении инициализируются случайным образом и могут быть измененыпозже (например, во время обучения сети).
Давайте рассмотрим конкретный пример:
import torch
from torch import nn
m = nn.Linear(2, 1)
input = torch.tensor([[1.0, -1.0]])
output = m(input)
Параметры инициализируются случайным образом
>>> m.weight
tensor([[0.2683, 0.2599]])
>>> m.bias
tensor([0.6741])
Выходные данныевычисляется как 1.0 * 0.2683 - 1.0 * 0.2599 + 0.6741 = 0.6825
>>> print(output)
tensor([[0.6825]]
В вашей сети есть три слоя: входной слой с 784 узлами, один скрытый слой с 256 узлами и выходной слой с 10 узлами.