Я пытаюсь обучить одну модель CNN с помощью Pytorch, чтобы выходные данные вели себя по-разному для разных типов входов. (т. е. если входные изображения являются человеческими существами, он выводит шаблон A, но если входными данными являются некоторые другие животные, он выводит шаблон B).
После некоторого онлайн-поиска кажется, что сиамская сеть связана с этим , Итак, у меня есть следующие 2 вопроса:
(1) Является ли сиамская сеть действительно хорошим способом обучения такой модели?
(2) Как мне реализовать реализацию? код в pytorch?
class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
self.cnn1 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(1, 4, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),
nn.ReflectionPad2d(1),
nn.Conv2d(4, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.ReflectionPad2d(1),
nn.Conv2d(8, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
)
self.fc1 = nn.Sequential(
nn.Linear(8*100*100, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 5))
def forward_once(self, x):
output = self.cnn1(x)
output = output.view(output.size()[0], -1)
output = self.fc1(output)
return output
def forward(self, input1, input2):
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)
return output1, output2
В настоящее время я пытаюсь найти какую-то существующую реализацию, которую я нашел в Интернете, как и приведенное выше определение класса. Это работает, но для этой модели всегда будет два входа и два выхода. Я согласен, что это удобно для обучения, но в идеале, это должен быть только один вход и один (два тоже хорошо) вывод во время логического вывода.
Может кто-нибудь дать некоторые рекомендации о том, как изменить код, чтобы сделать его один вход?