Прогнозирование следующего трека с использованием ванильного RNN в PyTorch - PullRequest
0 голосов
/ 12 апреля 2020

Для некоторого контекста у меня есть набор из 37 списков воспроизведения длиной 12 треков. Каждый трек был выбран вручную определенным образом. Ранние песни в плейлисте, как правило, более охлаждены, и по мере продвижения плейлиста треки начинают увеличиваться в темпе. Я решил посвятить себя проекту и создать глубокий генератор плейлистов.

Я реализую ванильный RNN «многие ко многим» в PyTorch и ищу ясности в том, как обучать RNN по одному плейлисту за раз, где затем каждая дорожка анализируется, и модель прогнозирует характеристики следующей дорожки.

На рисунке изображен RNN типа «многие ко многим» - в данном случае - каждая красная рамка - это особенности текущей дорожки, а противоположная синяя рамка - предсказанные функции следующего трека:

Many-to-many RNN

Набор функций (9), X, выглядит так: Spotify track features

Цель y просто отражает вышеуказанный набор функций следующей дорожки.

Для моего класса RNN это выглядит так:

class RNNEstimator(nn.Module):
    def __init__(self, input_size=9, hidden_size=30, output_size=9):
        super(RNNEstimator, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)

    def forward(self, inp, hidden):
        print("inp", inp.shape)
        print("hid", hidden.shape)
        combined = torch.cat((inp, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

Это взято из Страница учебных пособий по PyTorch . Однако я адаптировал класс RNN для вывода функций 9, а не двоичной классификации.

Набор данных списка воспроизведения был преобразован в тензор формы факела. Размер ([37, 12, 18]) и шаг (12, 1, 444)) - имеется в виду 37 списков воспроизведения, 12 длинных треков с 9 X_features + 9 y_features (18).

Функция train_rnn:

# Model Initiation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNEstimator(9, 30, 9)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.L1Loss()

# Training function for RNN
def train_rnn(model, train_loader, epochs, criterion, optimizer, device):

    model.train() # Make sure that the model is in training mode.

    # training loop is provided
    for epoch in range(1, epochs + 1):

        for batch in train_loader:

            total_loss = 0

            # get data
            batch_x = batch[:, :9, :].float().squeeze()
            batch_y = batch[:, 9:, :].float()

            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()

            hidden = model.initHidden()       

            # For each track in batch/playlist
            # TODO: THIS NEEDS WORK
            for x, y in zip(batch_x, batch_y):
                output, hidden = model(x, hidden)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.data.item()

        if epoch % 10 == 0:
            print('Epoch: {}/{}.............'.format(epoch, epochs), end=' ')
            print("Loss: {:.4f}".format(loss.item()))

Что я пытаюсь сделать Понимаете, как обучить эту модель одному списку воспроизведения в пакете?

Я получаю сообщение об ошибке от функции cat, например:

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Модель должна анализировать каждую дорожку (t) - через прямой метод - затем вывести следующий трек (t + 1). Скрытое состояние будет сбрасывать каждый список воспроизведения, если они независимы друг от друга.

1 Ответ

1 голос
/ 13 апреля 2020

Даже если вы хотите запустить с размером пакета 1, ваш ввод (x) все еще нуждается в измерении пакета. Попробуйте:

output, hidden = model(x.unsqueeze(0), hidden)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...