Можно ли жестко закодировать сверточные фильтры для обнаружения символов в CNN? - PullRequest
0 голосов
/ 03 февраля 2020

В Pytorch вы можете жестко закодировать ваши фильтры так, как вам нравится.

В данный момент я занимаюсь обнаружением текста, и мне нужно определить местоположение определенной информации. Эта информация всегда начинается с буквы « X ». Может ли это радикально улучшить производительность обнаружения, если я жестко закодирую фильтр ' X '? Вот что у меня есть:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

kernel = (torch.zeros((9, 9)) + \
          torch.eye(9) + \
          torch.rot90(torch.eye(9))).type(torch.bool)*1

print(kernel)
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 1]])

Мы можем визуализировать это так:

plt.imshow(kernel)
plt.show()

enter image description here

Затем мы можем установить вес фильтра следующим образом:

conv = nn.Conv2d(in_channels=1, 
                 out_channels=1, 
                 kernel_size=3, 
                 stride=3, 
                 bias=None)

conv.weight.data = kernel
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...