Can someone help me understand why my Keras-to-PyTorch conversion isn't working?
0 votes
/ 06 August 2020

I am trying to move from Keras to PyTorch and I am getting inconsistent behavior. First of all, I notice that the two models have a different number of parameters, and I also notice that training in Keras converges much faster and more smoothly than in PyTorch. Did I do something wrong? Can anyone help? The relevant code is below.

I have the following simple architecture in Keras:

from tensorflow.keras import layers, models

# `reg` is the kernel regularizer used throughout; its definition is not shown here.

def identity_block(input_tensor, units):
    x = layers.Dense(units, kernel_regularizer=reg)(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Dense(units, kernel_regularizer=reg)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Dense(units, kernel_regularizer=reg)(x)
    x = layers.BatchNormalization()(x)
    x = layers.add([x, input_tensor])
    x = layers.Activation('relu')(x)
    return x


def dens_block(input_tensor, units, reps=2):
    x = input_tensor

    for _ in range(reps):
        x = layers.Dense(units, kernel_regularizer=reg)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

    x = layers.Dense(units, kernel_regularizer=reg)(x)
    x = layers.BatchNormalization()(x)
    shortcut = layers.Dense(units, kernel_regularizer=reg)(input_tensor)
    shortcut = layers.BatchNormalization()(shortcut)
    x = layers.add([x, shortcut])
    x = layers.Activation('relu')(x)
    return x

def resnet_block(input_tensor, width=16):
    x = dens_block(input_tensor, width)
    x = identity_block(x, width)
    x = identity_block(x, width)
    return x

def RegResNet(input_size=8, reps=3, initial_weights=None, width=16, num_gpus=0, lr=1e-4):
    input_layer = layers.Input(shape=(input_size,))
    x = input_layer
    for _ in range(reps):
        x = resnet_block(x, width)

    x = layers.BatchNormalization()(x)
    x = layers.Dense(1, activation=None)(x)

    model = models.Model(inputs=input_layer, outputs=x)
    return model
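For reference, this is roughly how I build the model and print the summary. A minimal sketch: the regularizer `reg` and its strength are stand-ins here, since their actual definitions are not shown above.

from tensorflow.keras import regularizers

reg = regularizers.l2(1e-4)  # placeholder: the actual regularizer is not shown in the post

model = RegResNet(input_size=8, reps=3, width=16)
model.summary()  # prints the per-layer summary plus the parameter counts below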

This gives a summary, which I could not fit here, and the following parameter counts:

==================================================================================================
Total params: 9,905
Trainable params: 8,913
Non-trainable params: 992
==================================================================================================
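The 992 non-trainable parameters line up with the moving statistics of the BatchNormalization layers: the model has 31 of them (10 per resnet_block plus the final one), each holding a moving_mean and moving_variance of size 16, and 31 * 2 * 16 = 992. A quick check, assuming the `model` built above:

import numpy as np

bn_layers = [l for l in model.layers if isinstance(l, layers.BatchNormalization)]
print(len(bn_layers))  # 31
print(sum(int(np.prod(w.shape))  # moving_mean and moving_variance are non-trainable
          for l in bn_layers for w in l.non_trainable_weights))  # 992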

Below is my translation of the same architecture to PyTorch:

import torch.nn as nn

class _DenseBlock(nn.Module):
    def __init__(self, input_size, output_size):
        super(_DenseBlock, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.linear = nn.Linear(self.input_size, self.output_size)
        self.bnorm = nn.BatchNorm1d(self.output_size)

    def forward(self, x):
        x = self.linear(x)
        x = self.bnorm(x)
        x = nn.ReLU()(x)
        return x

class DenseBlock(nn.Module):
    def __init__(self, input_size, output_size, repetitions=2):
        super(DenseBlock, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.repetitions = repetitions
        self.dense_blocks = nn.ModuleList([_DenseBlock(self.input_size, self.output_size)]
                      + [_DenseBlock(self.output_size, self.output_size)
                         for _ in range(self.repetitions - 1)])
        self.b_norm1 = nn.BatchNorm1d(self.output_size)
        self.b_norm2 = nn.BatchNorm1d(self.output_size)
        self.linear_1 = nn.Linear(self.output_size, self.output_size)
        self.linear_2 = nn.Linear(self.input_size, self.output_size)

    def forward(self, x):
        identity = x
        
        for l in self.dense_blocks:
            x = l(x)

        x = self.linear_1(x)
        x = self.b_norm1(x)

        shortcut = self.linear_2(identity)
        shortcut = self.b_norm2(shortcut)

        x = shortcut + x
        x = nn.ReLU()(x)
        return x

class IdentityBlock(nn.Module):
    def __init__(self, input_size, output_size):
        super(IdentityBlock, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.linear_1 = nn.Linear(self.input_size, self.output_size)
        self.linear_2 = nn.Linear(self.input_size, self.output_size)
        self.linear_3 = nn.Linear(self.input_size, self.output_size)
        self.bnorm_1 = nn.BatchNorm1d(self.output_size)  # normalizes the output of linear_1
        self.bnorm_2 = nn.BatchNorm1d(self.output_size)
        self.bnorm_3 = nn.BatchNorm1d(self.output_size)

    def forward(self, x):
        input_tensor = x
        x = self.linear_1(x)
        x = self.bnorm_1(x)
        x = nn.ReLU()(x)

        x = self.linear_2(x)
        x = self.bnorm_2(x)
        x = nn.ReLU()(x)

        x = self.linear_3(x)
        x = self.bnorm_3(x)
        x = x + input_tensor
        x = nn.ReLU()(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, input_size, output_size):
        super(ResnetBlock, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.dense_block = DenseBlock(self.input_size, self.output_size)
        self.identity_block_1 = IdentityBlock(self.output_size, self.output_size)
        self.identity_block_2 = IdentityBlock(self.output_size, self.output_size)

    def forward(self, x):
        x = self.dense_block(x)
        x = self.identity_block_1(x)
        x = self.identity_block_2(x)
        return x

class RegResNet(nn.Module):

    def __init__(self, input_size, block_width, repetitions=3):
        super(RegResNet, self).__init__()
        self.input_size = input_size
        self.repetitions = repetitions
        self.block_width = block_width
        self.resnet_blocks = nn.ModuleList([ResnetBlock(self.input_size, self.block_width)]
                                           + [ResnetBlock(self.block_width, self.block_width)
                                              for _ in range(self.repetitions - 1)])
        self.bnorm = nn.BatchNorm1d(self.block_width)
        self.out = nn.Linear(self.block_width, 1)

    def forward(self, x):
        for layer in self.resnet_blocks:
            x = layer(x)

        x = self.bnorm(x)
        x = self.out(x)
        return x
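This is roughly how the summary below was produced. A sketch: the constructor arguments mirror the Keras defaults, and `summary` is from the torchsummary package.

from torchsummary import summary

model = RegResNet(input_size=8, block_width=16, repetitions=3)
summary(model, (8,), device="cpu")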

Using the torchsummary library, I got a summary (which I could not fit here) and the following parameter counts:

================================================================
Total params: 8,913
Trainable params: 8,913
Non-trainable params: 0
================================================================
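One thing I noticed when comparing the two summaries: PyTorch stores the batch-norm running statistics as buffers rather than parameters, so torchsummary does not count them, whereas Keras reports them as non-trainable. A quick check (a sketch, reusing the `model` above):

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
running_stats = sum(m.running_mean.numel() + m.running_var.numel()
                    for m in model.modules() if isinstance(m, nn.BatchNorm1d))
print(trainable)      # 8913, matching the Keras trainable count
print(running_stats)  # 992, matching the Keras non-trainable count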
...