Подача трехмерного медицинского изображения в двухмерный сверточный LSTM, по частям в PyTorch - PullRequest
0 голосов
/ 25 марта 2020

Я делаю сегментацию изображения. Я хочу обработать размер среза (глубину) 3D-изображения CT как размер времени или последовательности и использовать рекуррентную сеть на наборе 2D-срезов для захвата зависимостей.

Я использую Сверточный LSTM для использования LSTM для изображений. Точнее, мне нужен двунаправленный сверточный LSTM. Я использую convolutional_rnn.Conv2dLSTM() из упомянутого источника, который я не уверен, является лучшим вариантом или нет.

Теперь вопрос в том, как я могу это реализовать. В настоящее время это моя реализованная модель, которая, по-видимому, неверна (в PyTorch 1.3),

class model(nn.Module):
    def __init__(self, n_in_channels=1, CNN_out_ch=16):
        super().__init__()
        self.cnn = MyCNN(n_in_channels=n_in_channels)
        self.CNN_out_ch = CNN_out_ch
        self.rnn = convolutional_rnn.Conv2dLSTM(in_channels=CNN_out_ch, out_channels=CNN_out_ch,
                                   kernel_size=5, num_layers=2, bidirectional=True,
                                   dilation=1, dropout=0.5, batch_first=True)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(kernel_size=2)


    def forward(self, input_tensor):
        image_stack = torch.zeros((input_tensor.shape[0], input_tensor.shape[2],
                                  self.CNN_out_ch, input_tensor.shape[3], input_tensor.shape[4]))
        # CNN part
        for i in range(input_tensor.shape[2]):
            slice_output = self.cnn(input_tensor[:,:,i,:,:])
            image_stack[:,i] = slice_output

        # Recurrent part
        image_stack = self.dropout(image_stack)
        image_stack, _ = self.rnn1(image_stack)

        image_stack_new = []
        for i in range(image_stack.shape[1]):
            image_stack_new.append(self.pool(image_stack[:,i]))
        image_stack = torch.stack(image_stack_new, dim=1)

My input_tensor для forward() имеет значение shape(batch, channel, depth or slice, height, width). Сначала я передаю изображение в al oop по размеру среза в 2D CNN (последний полностью связанный слой в 2D CNN удаляется, и выводятся карты объектов с 16 каналами).

Затем снова соберите кусочки и измените их до (batch, depth, channel, height, width), как в LSTM, у нас есть измерение времени после пакета. Затем я применяю выпадение, а затем BiCLSTM.

Наконец, снова в al oop для измерения среза, примените 2D пул к каждому срезу.

Может кто-нибудь уловить проблему здесь?

...