OCR в pytorch / keras с LP RNet и CTCLoss не сходятся - PullRequest
2 голосов
/ 03 мая 2020

Я реализовал LP RNet из этой бумаги в кератах, а затем в pytorch. Для pytorch был использован этот код в качестве примера. Это не сходилось ни в одной версии. Потеря начинается с 30 и продолжается до 27 даже после 150000 эпох. В версии pytorch я не использую валидацию для более быстрого обучения.

то, что я пробовал:

  • инициализировал веса conv2d с помощью xavier_uniform_, xavier_normal и всех других, которые я нашел в torch.nn. init
  • использовал другой набор данных с 50 чистыми изображениями, которые использовали другие ребята для своей модели rcnn ocr
  • попробовал sgd вместо adam

полный колаб с pytorch равен здесь и мой набор данных здесь

В керасе мне когда-то удалось одеть модель поверх тренировочного набора. В моем тренировочном наборе всего 50 изображений, но я применяю случайный эффект, чтобы не перегружать. В своей статье авторы не указали, какой большой набор данных они использовали. Может быть проблема в реализации модели? Кроме того, в оригинальной статье содержалась пространственная трансформаторная сеть, которую я не использовал, потому что они сказали, что это необязательно, но может ли это быть проблемой?

class Softmax(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        #se aplica pt torch.Size([1, 30, 1, 74])
        return torch.nn.functional.log_softmax(x, 3)

class Reshape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # se aplica pt torch.Size([1, 30, 1, 74])
        #pentru CTCloss
        #trebuie sa fie (T,N,C) , where T=input length, N=batch size, C=number of classes
        x = x.permute(3, 0, 1, 2)
        return x.view(x.shape[0], x.shape[1], x.shape[2])

def init_weights(m):
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        #m.bias.data.fill_(0.01)

class small_basic_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ch_in, ch_out//4, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(ch_out//4, ch_out//4, kernel_size=(3,1), padding=(1,0)),
            nn.ReLU(),
            nn.Conv2d(ch_out//4, ch_out//4, kernel_size=(1,3), padding=(0,1)),
            nn.Conv2d(ch_out//4, ch_out, kernel_size=1)
        )
        self.block.apply(init_weights)

    def forward(self, x):
        return self.block(x)

class LPRNet(nn.Module):
    def __init__(self, class_num):
        super().__init__()
        self.lprnet = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=1),
            small_basic_block(ch_in=64, ch_out=128),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 2, 1)),
            small_basic_block(ch_in=64, ch_out=256),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            small_basic_block(ch_in=256, ch_out=256),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1,3,3), stride=(4,2,1)),
            nn.Dropout(0.5),
            nn.Conv2d(64, 256, kernel_size=(4,1), stride=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, class_num, kernel_size=(1,13), stride=1),
            nn.BatchNorm2d(num_features=class_num),
            nn.ReLU(), # torch.Size([1, 30, 1, 74])
            Softmax(),
            Reshape()
        )
        self.lprnet.apply(init_weights)

    def forward(self, x):
        return self.lprnet(x)

ctc_loss = nn.CTCLoss(blank=alphabet.index('$'))

model = LPRNet(len(alphabet)+1)
model.to(dev)
opt = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.StepLR(opt, step_size=100000, gamma=0.1)

def train(epochs):
    for i in range(epochs):
        X_data, Y_data, X_data_len, Y_data_len = train_set.next_batch()
        X_data = model(X_data)
        loss = ctc_loss(X_data, Y_data, X_data_len, Y_data_len)
        loss.backward()
        opt.step()
        opt.zero_grad()
        scheduler.step()
...