Реализовать сиамские нейронные сети в PyTorch так же просто, как дважды вызвать сетевую функцию на разных входах.
mynet = torch.nn.Sequential(
nn.Linear(10, 512),
nn.ReLU(),
nn.Linear(512, 2))
...
output1 = mynet(input1)
output2 = mynet(input2)
...
loss.backward()
При вызове loss.backwad()
PyTorch автоматически суммирует градиенты, поступающие из двух вызовов mynet
.
Вы можете найти полноценный пример здесь .