Какой-нибудь пример факела 0.4.0 nn.LayerNorm для nn.LSTMCell? - PullRequest
0 голосов
/ 03 мая 2018

В выпуске Pytorch 0.4.0 есть модуль nn.LayerNorm .

Я хочу внедрить этот уровень в мою сеть LSTM, хотя пока не могу найти ни одного примера реализации в сети LSTM.

И автор pytorch подразумевает, что это nn.LayerNorm применимо только через nn.LSTMCell s.

Будет очень полезно, если я смогу получить любое git-репо или какой-нибудь код, который реализует nn.LayerNorm на nn.LSTMcell или любую сеть LSTM с факелами.

Заранее спасибо

1 Ответ

0 голосов
/ 05 ноября 2018

Я тоже ищу решение. Вот пример из https://github.com/pytorch/pytorch/issues/11335
Благодаря @ jinserk * ​​1003 *

class LayerNormLSTMCell(nn.LSTMCell):

def __init__(self, input_size, hidden_size, bias=True):
    super().__init__(input_size, hidden_size, bias)

    self.ln_ih = nn.LayerNorm(4 * hidden_size)
    self.ln_hh = nn.LayerNorm(4 * hidden_size)
    self.ln_ho = nn.LayerNorm(hidden_size)

def forward(self, input, hidden=None):
    self.check_forward_input(input)
    if hidden is None:
        hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        cx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
    else:
        hx, cx = hidden
    self.check_forward_hidden(input, hx, '[0]')
    self.check_forward_hidden(input, cx, '[1]')

    gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \
             + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh))
    i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1)
    g = gates[:, (3 * self.hidden_size):].tanh()

    cy = (f * cx) + (i * g)
    hy = o * self.ln_ho(cy).tanh()
    return hy, cy


class LayerNormLSTM(nn.Module):

def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional

    num_directions = 2 if bidirectional else 1
    self.hidden0 = nn.ModuleList([
        LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                          hidden_size=hidden_size, bias=bias)
        for layer in range(num_layers)
    ])

    if self.bidirectional:
        self.hidden1 = nn.ModuleList([
            LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                              hidden_size=hidden_size, bias=bias)
            for layer in range(num_layers)
        ])

def forward(self, input, hidden=None):
    seq_len, batch_size, hidden_size = input.size()  # supports TxNxH only
    num_directions = 2 if self.bidirectional else 1
    if hidden is None:
        hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
        cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
    else:
        hx, cx = hidden

    ht = [[None, ] * (self.num_layers * num_directions)] * seq_len
    ct = [[None, ] * (self.num_layers * num_directions)] * seq_len

    if self.bidirectional:
        xs = input
        for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)):
            l0, l1 = 2 * l, 2 * l + 1
            h0, c0, h1, c1 = hx[l0], cx[l0], hx[l1], cx[l1]
            for t, (x0, x1) in enumerate(zip(xs, reversed(xs))):
                ht[t][l0], ct[t][l0] = layer0(x0, (h0, c0))
                h0, c0 = ht[t][l0], ct[t][l0]
                t = seq_len - 1 - t
                ht[t][l1], ct[t][l1] = layer1(x1, (h1, c1))
                h1, c1 = ht[t][l1], ct[t][l1]
            xs = [torch.cat((h[l0], h[l1]), dim=1) for h in ht]
        y  = torch.stack(xs)
        hy = torch.stack(ht[-1])
        cy = torch.stack(ct[-1])
    else:
        h, c = hx, cx
        for t, x in enumerate(input):
            for l, layer in enumerate(self.hidden0):
                ht[t][l], ct[t][l] = layer(x, (h[l], c[l]))
                x = ht[t][l]
            h, c = ht[t], ct[t]
        y  = torch.stack([h[-1] for h in ht])
        hy = torch.stack(ht[-1])
        cy = torch.stack(ct[-1])

    return y, (hy, cy)
...