LSTM зацикливается - PullRequest
       10

LSTM зацикливается

1 голос
/ 02 июля 2019

Я недавно внедрил имя, генерирующее RNN "с нуля", которое работало нормально, но далеко не идеально.Поэтому я подумал о том, чтобы попытать счастья с классом LSTM от pytorch, чтобы понять, имеет ли это значение.Действительно, это так, и выход выглядит лучше для первых 7 ~ 8 символов.Но затем сети зацикливаются и выводят такие вещи, как "laulaulaulau" или "rourourourou" (предполагается, что они генерируют французские имена).

Это часто встречающаяся проблема?Если да, то знаете ли вы, как это исправить?Я обеспокоен тем фактом, что сеть не производит токены EOS ... Это проблема, которая уже была задана здесь Почему моя модель keras LSTM застревает в бесконечном цикле? , но на самом деле это не такотсюда и ответил мой пост.

вот модель:

class pytorchLSTM(nn.Module):
    def __init__(self,input_size,hidden_size):
        super(pytorchLSTM,self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size,input_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim = 2)

    def forward(self, input, hidden)
            out, hidden = self.lstm(input,hidden)
            out = self.tanh(out)
            out = self.output_layer(out)
            out = self.softmax(out)
        return out, hidden

Вход и цель - две последовательности векторов с горячим кодированием соответственно с началом последовательности и концом вектора последовательности вначало и конец.Они представляют символы внутри имени, взятые из списка имен (база данных).Я использую и токен на каждое имя из базы данных.вот функция, которую я использую

def inputTensor(line):
#tensor starts with <start of sequence> token.
    tensor = torch.zeros(len(line)+1, 1, n_letters)
    tensor[0][0][n_letters - 2] = 1
    for li in range(len(line)):
        letter = line[li]
        tensor[li+1][0][all_letters.find(letter)] = 1
    return tensor

# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(len(line))]
    letter_indexes.append(n_letters - 1) # EOS
    return torch.LongTensor(letter_indexes)

обучающий цикл:


def train_lstm(model):
    start = time.time()
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters())
    n_iters = 20000
    print_every = 1000
    plot_every = 500
    all_losses = []
    total_loss = 0
    for iter in range(1,n_iters+1):
        line = randomChoice(category_line)
        input_line_tensor = inputTensor(line)
        target_line_tensor = targetTensor(line).unsqueeze(-1)
        optimizer.zero_grad()       
        loss = 0
        output, hidden = model(input_line_tensor)
        for i in range(input_line_tensor.size(0)):
            l = criterion(output[i], target_line_tensor[i])
            loss += l
        loss.backward()
        optimizer.step() 

функция выборки:

def sample():
    max_length = 20
    input = torch.zeros(1,1,n_letters)
    input[0][0][n_letters - 2] = 1
    output_name = ""
    hidden = (torch.zeros(2,1,lstm.hidden_size),torch.zeros(2,1,lstm.hidden_size)) 

    for i in range(max_length):
        output, hidden = lstm(input)
        output = output[-1][:][:]
        l = torch.multinomial(torch.exp(output[0]),num_samples = 1).item()
        if l == n_letters - 1:
            break
        else:
            letter = all_letters[l]
            output_name += letter
        input = inputTensor(letter)
    return output_name

Типичный сэмплированный вывод выглядит примерно так:

Laurayeerauerararauo
Leayealouododauodouo
Courouauurourourodau

Знаете, как я могу это улучшить?

Ответы [ 2 ]

2 голосов
/ 04 июля 2019

Я нашел объяснение:

При использовании экземпляров класса LSTM в составе RNN входные размеры по умолчанию: (seq_length,batch_dim,input_size).Чтобы иметь возможность интерпретировать вывод lstm как вероятность (через набор входных данных), мне нужно было передать его на уровень Linear перед вызовом Softmax, в котором и возникает проблема: Linear ожидают экземплярывходные данные должны иметь формат (batch_dim,seq_length,input_size).

. Чтобы исправить это, необходимо передать batch_first = True в качестве аргумента LSTM при создании, а затем передать RNN вход в виде(batch_dim, seq_length, input_size).

0 голосов
/ 02 июля 2019

Несколько советов по улучшению сети в порядке важности (и простота внедрения):

1.Обучающие данные

Если вы хотите, чтобы сгенерированные образцы выглядели реальными, вам нужно предоставить некоторые реальные данные в сеть.Найдите набор имен, разбейте их на буквы и преобразуйте в индексы.Один этот шаг дал бы более реалистичные имена.

2.Отдельные начальные и конечные токены.

Я бы пошел с <SON> (начало имени) и <EON> (конец имени).В этой конфигурации нейронная сеть может изучать комбинации букв, ведущих к <EON>, и комбинации букв, следующих за <SON>.ATM пытается втиснуть две разные концепции в один пользовательский токен.

3.Неподдерживаемое сохранение

Возможно, вы захотите придать вашим буквам некоторое семантическое значение вместо закодированных в одно горячее состояние векторов, проверьте word2vec для базового подхода.

По сути, каждая буква будетпредставлен N -мерным вектором (скажем, 50-мерным) и был бы ближе в пространстве, если буква встречается чаще рядом с другой буквой (a ближе к k, чем x).

Простой способ реализовать это - взять набор текстовых данных и попытаться предсказать следующую букву на каждом временном шаге.Каждая буква будет представлена ​​случайным вектором в начале, при этом обратное распространение буквенных представлений будет обновлено, чтобы отразить их сходство.

Проверьте учебник по встраиванию pytorch для получения дополнительной информации.

4.Другая архитектура

Возможно, вы захотите проверить идею Андрея Карпати о создании имен для детей.Это просто описано здесь .

По сути, после обучения вы кормите свою модель случайными буквами (скажем, 10) и говорите ей, чтобы предсказать следующую букву.

Вы удаляете последнюю букву из случайного начального числа и помещаете предсказанную на ее место.Итерируйте, пока не будет выведено <EON>.

...