Что является альтернативой PyTorch для Keras input_shape, output_shape, get_weights, get_config и summary - PullRequest
0 голосов
/ 23 ноября 2018

В Keras, после создания модели, мы можем видеть ее входные и выходные формы, используя model.input_shape, model.output_shape.Для весов и конфигурации мы можем использовать model.get_weights() и model.get_config() соответственно.

Каковы аналогичные альтернативы для PyTorch?Также есть ли какие-либо другие функции, которые нам нужно знать для проверки модели PyTorch?

Чтобы получить сводку в PyTorch, мы печатаем модель print(model), но это дает меньшую информацию, чем model.summary().Есть ли лучшее резюме для PyTorch?

1 Ответ

0 голосов
/ 23 ноября 2018

В pytorch нет метода "model.summary ()".Вам нужно использовать встроенные методы и поля модели.

Например, я настроил модель inception_v3.Чтобы получить информацию, мне нужно использовать много разных полей.Например:

IN:

print(model) # print network architecture

OUT

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
   ...

IN:

for i in model.state_dict().keys():
    print(i) 
#print keys of dict with values of learned weights, bias, parameters

OUT:

    Conv2d_1a_3x3.conv.weight
    Conv2d_1a_3x3.bn.weight
    Conv2d_1a_3x3.bn.bias
    Conv2d_1a_3x3.bn.running_mean
    Conv2d_1a_3x3.bn.running_var
    Conv2d_1a_3x3.bn.num_batches_tracked
    Conv2d_2a_3x3.conv.weight
    Conv2d_2a_3x3.bn.weight
    Conv2d_2a_3x3.bn.bias
    Conv2d_2a_3x3.bn.running_mean 
    ...

Поэтому, если я хочу получить веса для слоя CNN в Conv2d_1a_3x3, я ищу ключ "Conv2d_1a_3x3.conv.weight":

print("model.save_dict()["Conv2d_1a_3x3.conv.weight"])

OUT:

tensor([[[[-0.2103, -0.3441, -0.0344],
          [-0.1420, -0.2520, -0.0280],
          [ 0.0736,  0.0183,  0.0381]],

         [[ 0.1417,  0.1593,  0.0506],
          [ 0.0828,  0.0854,  0.0186],
          [ 0.0283,  0.0144,  0.0508]],
...

Если вы хотите увидеть используемые гиперпараметры из оптимизатора:

optimizer.param_groups

OUT:

[{'dampening': 0,
  'lr': 0.01,
  'momentum': 0.01,
  'nesterov': False,
  'params': [Parameter containing:
   tensor([[[[-0.2103, -0.3441, -0.0344],
             [-0.1420, -0.2520, -0.0280],
             [ 0.0736,  0.0183,  0.0381]],
          ...
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...