Как получить выход из определенного слоя из модели PyTorch? - PullRequest
0 голосов
/ 13 октября 2018

Как извлечь элементы из определенного слоя из предварительно обученной модели PyTorch (такой как ResNet или VGG), не делая повторный проход вперед?

Ответы [ 2 ]

0 голосов
/ 21 мая 2019

Я пытаюсь извлечь особенности определенного слоя предварительно обученной модели.Код поиска, основанный на ответе bryant1410, работает, однако значения template_feature_map изменились, и я ничего не сделал с ним.

Вывод 6-го слоя модели должен иметь отрицательные значения, как показывает первая печать (template_feature_map),Но отрицательные значения, которые должны поддерживаться во второй печати (template_feature_map), заменяются нулями, я не знаю почему.Если вы знаете механизм этого, скажите, пожалуйста, как сохранить отрицательные значения.

vgg_feature = models.vgg13(pretrained=True).features
template_feature_map=None
def save_template_feature_map(module, input, output):
    global template_feature_map
    template_feature_map=output
    print(template_feature_map)
template_handle = vgg_feature[5].register_forward_hook(save_template_feature_map)
vgg_feature(template[0])
print(template_feature_map)

Вывод двух print (template_feature_map):

tensor([[[[-5.7389e-01, -2.7154e+00, -4.0990e+00,  ...,  4.1902e+00,
            3.1757e+00,  2.2461e+00],
          [-2.2217e+00, -4.3395e+00, -6.8158e+00,  ..., -1.4454e+00,
            9.8012e-01, -2.3653e+00],
          [-4.1940e+00, -6.3235e+00, -6.8422e+00,  ..., -2.8329e+00,
            2.5570e+00, -2.7704e+00],
          ...,
          [-3.3250e+00,  1.3792e-01,  5.4926e+00,  ..., -4.1722e+00,
           -6.1008e-01, -2.6037e+00],
          [ 1.5377e+00,  6.0671e-01,  2.0974e+00,  ...,  1.2441e+00,
            1.5033e+00, -2.7246e+00],
          [ 6.8857e-01, -3.5160e-02,  6.7858e-01,  ...,  1.2052e+00,
            1.4533e+00, -1.4160e+00]],

         [[ 6.8798e-01,  1.6971e+00,  2.1629e+00,  ...,  3.1701e-01,
            8.5424e-01,  2.8768e+00],
          [ 1.4013e+00,  2.7217e+00,  2.1476e+00,  ...,  3.1156e+00,
            4.4858e+00,  3.6936e+00],
          [ 3.1807e+00,  2.2245e+00,  2.4665e+00,  ...,  1.3838e+00,
            1.0580e-02, -3.1445e-03],
          ...,
          [-4.7298e+00, -3.3037e+00, -1.2982e+00,  ...,  2.3266e-01,
            6.7711e+00,  3.8166e+00],
          [-4.7972e+00, -5.4591e+00, -2.5201e+00,  ...,  3.7584e+00,
            5.1524e+00,  2.3072e+00],
          [-2.4306e+00, -2.8033e+00, -2.0912e+00,  ...,  1.9888e+00,
            2.0582e+00,  1.9266e+00]],

         [[-4.4257e+00, -4.6331e+00, -3.3580e-03,  ..., -8.2233e+00,
           -7.4645e+00, -1.7361e+00],
          [-4.5593e+00, -8.4195e+00, -8.8428e+00,  ..., -6.7950e+00,
           -1.4665e+01, -2.5335e+00],
          [-2.3481e+00, -3.8543e+00, -3.5965e+00,  ..., -1.5105e+00,
           -1.6923e+01, -5.9852e+00],
          ...,
          [-8.0165e+00,  8.0185e+00,  6.5506e+00,  ...,  5.3241e+00,
            3.3854e+00, -1.6342e+00],
          [-1.3689e+01, -2.2930e+00,  4.7097e+00,  ...,  3.2021e+00,
            2.9208e+00, -8.0228e-01],
          [-1.3055e+01, -1.1470e+01, -8.4442e+00,  ...,  1.8155e-02,
           -6.2866e-02, -2.0333e+00]],

         ...,

         [[ 3.4622e+00, -1.2417e+00, -5.0749e+00,  ...,  5.3184e+00,
            1.4744e+01,  8.3968e+00],
          [-2.7820e+00, -9.1911e+00, -1.1069e+01,  ...,  2.5380e+00,
            9.8336e+00,  4.0623e+00],
          [-3.9794e+00, -1.0140e+01, -9.9133e+00,  ...,  3.0999e+00,
            5.5936e+00,  2.5775e+00],
          ...,
          [ 2.0299e+00,  2.1304e-01, -2.2307e+00,  ...,  1.1388e+01,
            8.8098e+00,  1.8991e+00],
          [ 8.0663e-01, -1.5073e+00,  3.3977e-01,  ...,  8.5316e+00,
            4.9923e+00, -3.6818e-01],
          [-3.5146e+00, -7.2647e+00, -5.4331e+00,  ..., -1.9781e+00,
           -3.4463e+00, -4.9034e+00]],

         [[-3.2915e+00, -7.3263e+00, -6.8458e+00,  ...,  2.3122e+00,
            9.7774e-01, -1.3498e+00],
          [-4.5396e+00, -8.6832e+00, -8.8582e+00,  ...,  7.1535e-02,
           -4.1133e+00, -4.4045e+00],
          [-4.8781e+00, -7.0239e+00, -4.7350e+00,  ..., -3.6954e+00,
           -9.6687e+00, -8.8289e+00],
          ...,
          [-4.7072e+00, -4.4823e-01,  1.7099e+00,  ...,  3.7923e+00,
            1.6887e+00, -4.3305e+00],
          [-5.5120e+00, -3.2324e+00,  2.3594e+00,  ...,  4.6031e+00,
            1.8856e+00, -4.0147e+00],
          [-5.1355e+00, -5.5335e+00, -1.7738e+00,  ...,  1.6159e+00,
           -1.3950e+00, -4.1055e+00]],

         [[-2.0252e+00, -2.3971e+00, -1.6477e+00,  ..., -3.3740e+00,
           -4.9965e+00, -2.1219e+00],
          [-7.6059e-01, -3.3901e-01, -1.8980e-01,  ..., -4.3286e+00,
           -7.1350e+00, -3.9186e+00],
          [ 8.4101e-01,  1.3403e+00,  2.5821e-01,  ..., -5.1847e+00,
           -7.1829e+00, -3.7724e+00],
          ...,
          [-6.0619e+00, -5.6475e+00, -1.6446e+00,  ..., -9.2322e+00,
           -9.1981e+00, -5.5239e+00],
          [-7.4606e+00, -7.6054e+00, -5.8401e+00,  ..., -7.6998e+00,
           -6.4111e+00, -2.9374e+00],
          [-6.4147e+00, -7.2813e+00, -6.1880e+00,  ..., -4.6726e+00,
           -3.1090e+00, -7.8383e-01]]]], grad_fn=<MkldnnConvolutionBackward>)
tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.1902e+00,
           3.1757e+00, 2.2461e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           9.8012e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           2.5570e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 1.3792e-01, 5.4926e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.5377e+00, 6.0671e-01, 2.0974e+00,  ..., 1.2441e+00,
           1.5033e+00, 0.0000e+00],
          [6.8857e-01, 0.0000e+00, 6.7858e-01,  ..., 1.2052e+00,
           1.4533e+00, 0.0000e+00]],

         [[6.8798e-01, 1.6971e+00, 2.1629e+00,  ..., 3.1701e-01,
           8.5424e-01, 2.8768e+00],
          [1.4013e+00, 2.7217e+00, 2.1476e+00,  ..., 3.1156e+00,
           4.4858e+00, 3.6936e+00],
          [3.1807e+00, 2.2245e+00, 2.4665e+00,  ..., 1.3838e+00,
           1.0580e-02, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.3266e-01,
           6.7711e+00, 3.8166e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.7584e+00,
           5.1524e+00, 2.3072e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.9888e+00,
           2.0582e+00, 1.9266e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 8.0185e+00, 6.5506e+00,  ..., 5.3241e+00,
           3.3854e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 4.7097e+00,  ..., 3.2021e+00,
           2.9208e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.8155e-02,
           0.0000e+00, 0.0000e+00]],

         ...,

         [[3.4622e+00, 0.0000e+00, 0.0000e+00,  ..., 5.3184e+00,
           1.4744e+01, 8.3968e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.5380e+00,
           9.8336e+00, 4.0623e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.0999e+00,
           5.5936e+00, 2.5775e+00],
          ...,
          [2.0299e+00, 2.1304e-01, 0.0000e+00,  ..., 1.1388e+01,
           8.8098e+00, 1.8991e+00],
          [8.0663e-01, 0.0000e+00, 3.3977e-01,  ..., 8.5316e+00,
           4.9923e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.3122e+00,
           9.7774e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 7.1535e-02,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 1.7099e+00,  ..., 3.7923e+00,
           1.6887e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 2.3594e+00,  ..., 4.6031e+00,
           1.8856e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.6159e+00,
           0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.4101e-01, 1.3403e+00, 2.5821e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]]]], grad_fn=<ThresholdBackward1>)
0 голосов
/ 13 октября 2018

Вы можете зарегистрировать перемотку вперед на конкретном нужном слое.Что-то вроде:

def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)

model(some_input)

Например, чтобы получить res5c выход в ResNet, вы можете использовать переменную nonlocal (или global в Python 2):

res5c_output = None

def res5c_hook(module, input_, output):
    nonlocal res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)

resnet(some_input)

# Then, use `res5c_output`.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...