PyTorch LSTM: RuntimeError: неверный аргумент 0: размеры тензоров должны совпадать, за исключением измерения 0. Получены 1219 и 440 в измерении 1 - PullRequest
2 голосов
/ 10 февраля 2020

У меня есть основа c PyTorch LSTM:

import torch.nn as nn
import torch.nn.functional as F

class BaselineLSTM(nn.Module):
    def __init__(self):
        super(BaselineLSTM, self).__init__()

        self.lstm = nn.LSTM(input_size=13, hidden_size=13)

    def forward(self, x):
        x = self.lstm(x)

        return x

Для моих данных у меня есть:

train_set = CorruptedAudioDataset(corrupted_path, train_set=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, **kwargs)

Мой CorruptedAudioDataset имеет:

    def __getitem__(self, index):
        corrupted_sound_file = SoundFile(self.file_paths[index])
        corrupted_samplerate = corrupted_sound_file.samplerate
        corrupted_signal_audio_array = corrupted_sound_file.read()

        clean_path = self.file_paths[index].split('/')
        # print(self.file_paths[index], clean_path)
        clean_sound_file = SoundFile(self.file_paths[index])
        clean_samplerate = clean_sound_file.samplerate
        clean_signal_audio_array = clean_sound_file.read()


        corrupted_mfcc = mfcc(corrupted_signal_audio_array, samplerate=corrupted_samplerate)
        clean_mfcc = mfcc(clean_signal_audio_array, samplerate=clean_samplerate)


        print('return', corrupted_mfcc.shape, clean_mfcc.shape)
        return corrupted_mfcc, clean_mfcc

Мое обучение l oop выглядит так:

    model = BaselineLSTM()
    for epoch in range(300):
        for inputs, outputs in train_loader:
            print('inputs', inputs)

И вот строка, в которой я получаю сообщение об ошибке:

  File "train_lstm_baseline.py", line 47, in train
    for inputs, outputs in train_loader:
...
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 1219 and 440 in dimension 1 at ../aten/src/TH/generic/THTensor.cpp:612

Не уверен, что я делаю неправильно. Любая помощь будет оценена. Спасибо!

1 Ответ

2 голосов
/ 10 февраля 2020

Это исключение выдается в основном потому, что вы загружаете партии различной формы. Поскольку они хранятся в одном и том же тензоре, все образцы должны иметь одинаковую форму. В этом случае у вас есть вход в измерении 0 с 1219 и 440, что невозможно. Например, у вас есть что-то вроде:

torch.Size([1, 1219])
torch.Size([1, 440])
torch.Size([1, 550])
...

У вас должно быть:

torch.Size([1, n])
torch.Size([1, n])
torch.Size([1, n])
...

Самый простой способ решить эту проблему - установить batch_size=1. Однако это может задержать ваш код.

Лучший способ - установить данные в одну и ту же форму. В этом случае вам нужно оценить свою проблему, чтобы проверить, возможно ли это.

Надеюсь, это поможет.

...