В Pytorch вы можете жестко закодировать ваши фильтры так, как вам нравится.
В данный момент я занимаюсь обнаружением текста, и мне нужно определить местоположение определенной информации. Эта информация всегда начинается с буквы « X ». Может ли это радикально улучшить производительность обнаружения, если я жестко закодирую фильтр ' X '? Вот что у меня есть:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
kernel = (torch.zeros((9, 9)) + \
torch.eye(9) + \
torch.rot90(torch.eye(9))).type(torch.bool)*1
print(kernel)
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 1, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 1, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 1]])
Мы можем визуализировать это так:
plt.imshow(kernel)
plt.show()
Затем мы можем установить вес фильтра следующим образом:
conv = nn.Conv2d(in_channels=1,
out_channels=1,
kernel_size=3,
stride=3,
bias=None)
conv.weight.data = kernel