Как реализовать текущие функции активации pytorch с параметрами? - PullRequest
0 голосов
/ 14 января 2019

Я ищу простой способ использования функции активации, которая существует в библиотеке pytorch, но использует какой-то параметр. например:

Tanh (х / 10)

Единственный способ, которым я пришел к поиску решения, - полностью реализовать пользовательскую функцию с нуля. Есть ли лучший / более элегантный способ сделать это?

редактирование:

Я ищу способ добавить в мою модель функцию Tanh (x / 10), а не просто Tanh (x). Вот соответствующий блок кода:

    self.model = nn.Sequential()
    for i in range(len(self.layers)-1):
        self.model.add_module("linear_layer_" + str(i), nn.Linear(self.layers[i], self.layers[i + 1]))
        if activations == None:
            self.model.add_module("activation_" + str(i), nn.Tanh())
        else:
            if activations[i] == "T":
                self.model.add_module("activation_" + str(i), nn.Tanh())
            elif activations[i] == "R":
                self.model.add_module("activation_" + str(i), nn.ReLU())
            else:
                #no activation
                pass

Ответы [ 2 ]

0 голосов
/ 14 января 2019

Вы можете создать слой с параметром умножения:

import torch
import torch.nn as nn

class CustomTanh(nn.Module):

    #the init method takes the parameter:
    def __init__(self, multiplier):
        self.multiplier = multiplier

    #the forward calls it:
    def forward(self, x):
        x = self.multiplier * x
        return torch.tanh(x)

Добавьте его к своим моделям с CustomTanh(1/10) вместо nn.Tanh().

0 голосов
/ 14 января 2019

Вместо того, чтобы определять его как определенную функцию, вы можете встроить его в пользовательский слой .

Например, ваше решение может выглядеть так:


import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)
        self.fc3 = nn.Softmax()

    def forward(self, x):
        return self.fc3(self.fc2(torch.tanh(self.fc1(x)/10)))

, где torch.tanh(output/10) встроено в функцию пересылки вашего модуля.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...