Умножить карту объектов на скаляр в pytorch - PullRequest
0 голосов
/ 22 января 2020

У меня проблема с бинарной классификацией, в наборе данных есть изображение и переменная, у меня есть идея сравнить изображение и переменную вместе.

Каждый раз, когда я прохожу конво-слой, я хочу умножить весовой скаляр для всей карты объектов, где весовой скаляр вычисляется из слоя af c.

Например, предположим, что размер пакета равен 8, есть два тензора x1 и x2, где * Размер 1007 * равен (8,3,224,224), а размер x2 равен (8,16).

import torch 
from torch.nn import Module, Sequential
from torch.nn import Conv2d, BatchNorm2d, ReLU, MaxPool2d, Softmax, Linear
import numpy
batch_size = 8
x1 = torch.rand(batch_size*3*224*224).view(batch_size,3,224,224)
x2 = torch.rand(batch_size*16).view(batch_size,16)

Я определяю conv-layer и f c -layer и вычисляю вывод из изображения и переменной.

conv_01 = Conv2d(in_channels=3, out_channels= 9, kernel_size=3, stride=1, padding=1)
linear_02 = Linear(16, 1)
c1 = conv_01(x1) ## torch.Size([8, 9, 224, 224])
c2 = linear_02(x2) ## torch.Size([8, 1])

Проблема заключается в том, чтобы написать подходящий код, как показано ниже.

## I want to do like blow
## c1[0,:,:,:] = c1[0,:,:,:] * c2[0,0] # 1st data in the mini-batch
## c1[1,:,:,:] = c1[1,:,:,:] * c2[0,1] # 2nd data in the mini-batch
## c1[2,:,:,:] = c1[2,:,:,:] * c2[0,2]
## c1[3,:,:,:] = c1[3,:,:,:] * c2[0,3]
## c1[4,:,:,:] = c1[4,:,:,:] * c2[0,4]
## c1[5,:,:,:] = c1[5,:,:,:] * c2[0,5]
## c1[6,:,:,:] = c1[6,:,:,:] * c2[0,6]
## c1[7,:,:,:] = c1[7,:,:,:] * c2[0,7]
## output is a (8, 9, 224, 224)
## and do more layer like this operation

Я уже взглянул на карту Multiply Feature с помощью обучаемого скалярное . Но эта поддержка поддерживается только тогда, когда размер пакета равен 1, но в моем случае размер пакета больше 1. Как написать подходящий код для функции forward в моем случае? Большое спасибо.

1 Ответ

0 голосов
/ 22 января 2020
result = c1 * c2.reshape((-1,1,1,1))

Вы можете изменить свою c2 форму torch.Size([8, 1]) до torch.Size([8, 1, 1, 1]), используя torch.reshape, так что вы можете сделать умножение на c1 shape torch.Size([8, 9, 224, 224])

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...